aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/distributed_runtime/eager/eager_service_impl.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/core/distributed_runtime/eager/eager_service_impl.cc')
-rw-r--r--tensorflow/core/distributed_runtime/eager/eager_service_impl.cc36
1 files changed, 33 insertions, 3 deletions
diff --git a/tensorflow/core/distributed_runtime/eager/eager_service_impl.cc b/tensorflow/core/distributed_runtime/eager/eager_service_impl.cc
index 466e779fab..916c8720f0 100644
--- a/tensorflow/core/distributed_runtime/eager/eager_service_impl.cc
+++ b/tensorflow/core/distributed_runtime/eager/eager_service_impl.cc
@@ -81,10 +81,11 @@ Status GetNumRetvals(tensorflow::EagerContext* context, const string& op_name,
Status EagerServiceImpl::CreateContext(const CreateContextRequest* request,
CreateContextResponse* response) {
- //make sure env_ , env_->rendezvous_mgr available
+ // make sure env_ , env_->rendezvous_mgr available
if (env_ == nullptr || env_->rendezvous_mgr == nullptr) {
- return tensorflow::errors::Internal("invalid eager env_ or env_->rendezvous_mgr.");
- }
+ return tensorflow::errors::Internal(
+ "invalid eager env_ or env_->rendezvous_mgr.");
+ }
std::vector<tensorflow::Device*> devices;
TF_RETURN_IF_ERROR(tensorflow::DeviceFactory::AddDevices(
@@ -266,6 +267,35 @@ Status EagerServiceImpl::RegisterFunction(
return context->Context()->AddFunctionDef(request->function_def());
}
+Status EagerServiceImpl::SendTensor(const SendTensorRequest* request,
+ SendTensorResponse* response) {
+ ServerContext* context = nullptr;
+ TF_RETURN_IF_ERROR(GetServerContext(request->context_id(), &context));
+ core::ScopedUnref context_unref(context);
+
+ tensorflow::gtl::InlinedVector<tensorflow::TensorHandle*, 2> tensors;
+ for (const auto& tensor_proto : request->tensors()) {
+ Tensor tensor;
+ if (!tensor.FromProto(tensor_proto)) {
+ return errors::InvalidArgument("Unable to parse tensor proto");
+ }
+
+ TensorHandle* tensor_handle =
+ new TensorHandle(tensor, nullptr, nullptr, nullptr);
+
+ TensorHandle* copied_handle = nullptr;
+ TF_RETURN_IF_ERROR(EagerCopyToDevice(tensor_handle, context->Context(),
+ request->device_name().c_str(),
+ &copied_handle));
+ tensors.push_back(copied_handle);
+ tensor_handle->Unref();
+ }
+
+ context->AddOperationOutputs(tensors, request->op_id());
+
+ return Status::OK();
+}
+
tensorflow::Status EagerServiceImpl::GetServerContext(
uint64 context_id, ServerContext** server_context) {
mutex_lock l(contexts_mu_);