aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/distributed_runtime/eager/eager_service_impl.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/core/distributed_runtime/eager/eager_service_impl.cc')
-rw-r--r--tensorflow/core/distributed_runtime/eager/eager_service_impl.cc40
1 files changed, 37 insertions, 3 deletions
diff --git a/tensorflow/core/distributed_runtime/eager/eager_service_impl.cc b/tensorflow/core/distributed_runtime/eager/eager_service_impl.cc
index 5a26d5bf48..916c8720f0 100644
--- a/tensorflow/core/distributed_runtime/eager/eager_service_impl.cc
+++ b/tensorflow/core/distributed_runtime/eager/eager_service_impl.cc
@@ -63,10 +63,10 @@ Status GetNumRetvals(tensorflow::EagerContext* context, const string& op_name,
}
*num_retvals += iter->second.i();
} else if (!output_arg.type_list_attr().empty()) {
- auto iter = attrs.find(output_arg.number_attr());
+ auto iter = attrs.find(output_arg.type_list_attr());
if (iter == attrs.end()) {
- return errors::InvalidArgument("Unable to find number_attr ",
- output_arg.number_attr(),
+ return errors::InvalidArgument("Unable to find type_list_attr ",
+ output_arg.type_list_attr(),
" for Op: ", op_name);
}
*num_retvals += iter->second.list().type_size();
@@ -81,6 +81,11 @@ Status GetNumRetvals(tensorflow::EagerContext* context, const string& op_name,
Status EagerServiceImpl::CreateContext(const CreateContextRequest* request,
CreateContextResponse* response) {
+ // make sure env_ , env_->rendezvous_mgr available
+ if (env_ == nullptr || env_->rendezvous_mgr == nullptr) {
+ return tensorflow::errors::Internal(
+ "invalid eager env_ or env_->rendezvous_mgr.");
+ }
std::vector<tensorflow::Device*> devices;
TF_RETURN_IF_ERROR(tensorflow::DeviceFactory::AddDevices(
@@ -262,6 +267,35 @@ Status EagerServiceImpl::RegisterFunction(
return context->Context()->AddFunctionDef(request->function_def());
}
+Status EagerServiceImpl::SendTensor(const SendTensorRequest* request,
+ SendTensorResponse* response) {
+ ServerContext* context = nullptr;
+ TF_RETURN_IF_ERROR(GetServerContext(request->context_id(), &context));
+ core::ScopedUnref context_unref(context);
+
+ tensorflow::gtl::InlinedVector<tensorflow::TensorHandle*, 2> tensors;
+ for (const auto& tensor_proto : request->tensors()) {
+ Tensor tensor;
+ if (!tensor.FromProto(tensor_proto)) {
+ return errors::InvalidArgument("Unable to parse tensor proto");
+ }
+
+ TensorHandle* tensor_handle =
+ new TensorHandle(tensor, nullptr, nullptr, nullptr);
+
+ TensorHandle* copied_handle = nullptr;
+ TF_RETURN_IF_ERROR(EagerCopyToDevice(tensor_handle, context->Context(),
+ request->device_name().c_str(),
+ &copied_handle));
+ tensors.push_back(copied_handle);
+ tensor_handle->Unref();
+ }
+
+ context->AddOperationOutputs(tensors, request->op_id());
+
+ return Status::OK();
+}
+
tensorflow::Status EagerServiceImpl::GetServerContext(
uint64 context_id, ServerContext** server_context) {
mutex_lock l(contexts_mu_);