diff options
Diffstat (limited to 'tensorflow/core/distributed_runtime/eager/eager_service_impl.cc')
-rw-r--r-- | tensorflow/core/distributed_runtime/eager/eager_service_impl.cc | 40 |
1 files changed, 37 insertions, 3 deletions
diff --git a/tensorflow/core/distributed_runtime/eager/eager_service_impl.cc b/tensorflow/core/distributed_runtime/eager/eager_service_impl.cc index 5a26d5bf48..916c8720f0 100644 --- a/tensorflow/core/distributed_runtime/eager/eager_service_impl.cc +++ b/tensorflow/core/distributed_runtime/eager/eager_service_impl.cc @@ -63,10 +63,10 @@ Status GetNumRetvals(tensorflow::EagerContext* context, const string& op_name, } *num_retvals += iter->second.i(); } else if (!output_arg.type_list_attr().empty()) { - auto iter = attrs.find(output_arg.number_attr()); + auto iter = attrs.find(output_arg.type_list_attr()); if (iter == attrs.end()) { - return errors::InvalidArgument("Unable to find number_attr ", - output_arg.number_attr(), + return errors::InvalidArgument("Unable to find type_list_attr ", + output_arg.type_list_attr(), " for Op: ", op_name); } *num_retvals += iter->second.list().type_size(); @@ -81,6 +81,11 @@ Status GetNumRetvals(tensorflow::EagerContext* context, const string& op_name, Status EagerServiceImpl::CreateContext(const CreateContextRequest* request, CreateContextResponse* response) { + // make sure env_ , env_->rendezvous_mgr available + if (env_ == nullptr || env_->rendezvous_mgr == nullptr) { + return tensorflow::errors::Internal( + "invalid eager env_ or env_->rendezvous_mgr."); + } std::vector<tensorflow::Device*> devices; TF_RETURN_IF_ERROR(tensorflow::DeviceFactory::AddDevices( @@ -262,6 +267,35 @@ Status EagerServiceImpl::RegisterFunction( return context->Context()->AddFunctionDef(request->function_def()); } +Status EagerServiceImpl::SendTensor(const SendTensorRequest* request, + SendTensorResponse* response) { + ServerContext* context = nullptr; + TF_RETURN_IF_ERROR(GetServerContext(request->context_id(), &context)); + core::ScopedUnref context_unref(context); + + tensorflow::gtl::InlinedVector<tensorflow::TensorHandle*, 2> tensors; + for (const auto& tensor_proto : request->tensors()) { + Tensor tensor; + if (!tensor.FromProto(tensor_proto)) { + return errors::InvalidArgument("Unable to parse tensor proto"); + } + + TensorHandle* tensor_handle = + new TensorHandle(tensor, nullptr, nullptr, nullptr); + + TensorHandle* copied_handle = nullptr; + TF_RETURN_IF_ERROR(EagerCopyToDevice(tensor_handle, context->Context(), + request->device_name().c_str(), + &copied_handle)); + tensors.push_back(copied_handle); + tensor_handle->Unref(); + } + + context->AddOperationOutputs(tensors, request->op_id()); + + return Status::OK(); +} + tensorflow::Status EagerServiceImpl::GetServerContext( uint64 context_id, ServerContext** server_context) { mutex_lock l(contexts_mu_); |