diff options
author | 2017-09-13 20:40:32 -0700 | |
---|---|---|
committer | 2017-09-13 20:45:06 -0700 | |
commit | de0bc082f153e36f9919c2cac8fc1063fe3c9186 (patch) | |
tree | ee533587844e238296306998a114d5f8ab28e539 /tensorflow/core | |
parent | ad1069e5900157a3a2a782a3f2a0aa62b0ebab19 (diff) |
Making sure that the src_incarnation field on the ParsedKey for the Sends and Recvs is set correctly.
PiperOrigin-RevId: 168635306
Diffstat (limited to 'tensorflow/core')
4 files changed, 76 insertions, 32 deletions
diff --git a/tensorflow/core/common_runtime/function.cc b/tensorflow/core/common_runtime/function.cc index 4aeacc6d61..d886a02305 100644 --- a/tensorflow/core/common_runtime/function.cc +++ b/tensorflow/core/common_runtime/function.cc @@ -579,6 +579,15 @@ void FunctionLibraryRuntimeImpl::RunRemote(const Options& opts, Handle handle, done(s); return; } + int64 src_incarnation, target_incarnation; + s = parent_->GetDeviceIncarnation(source_device, &src_incarnation); + s.Update(parent_->GetDeviceIncarnation(target_device, &target_incarnation)); + if (!s.ok()) { + delete frame; + delete exec_args; + done(s); + return; + } // The ProcFLR sends the arguments to the function from the source_device to // the target_device. So here we receive those arguments. Similarly, when the @@ -586,10 +595,11 @@ void FunctionLibraryRuntimeImpl::RunRemote(const Options& opts, Handle handle, // to the source_device (caller) so that the ProcFLR can receive them later. std::vector<Tensor>* remote_args = new std::vector<Tensor>; ProcessFunctionLibraryRuntime::ReceiveTensorsAsync( - source_device, target_device, "arg_", args.size(), rendez_args, - rendezvous, remote_args, - [frame, remote_args, item, source_device, target_device, rendezvous, - rendez_args, rets, done, exec_args](const Status& status) { + source_device, target_device, "arg_", src_incarnation, args.size(), + rendez_args, rendezvous, remote_args, + [frame, remote_args, item, source_device, target_device, + target_incarnation, rendezvous, rendez_args, rets, done, + exec_args](const Status& status) { Status s = status; s = frame->SetArgs(*remote_args); if (!s.ok()) { @@ -600,9 +610,9 @@ void FunctionLibraryRuntimeImpl::RunRemote(const Options& opts, Handle handle, return; } item->exec->RunAsync( - *exec_args, - [item, frame, rets, done, source_device, target_device, rendezvous, - rendez_args, remote_args, exec_args](const Status& status) { + *exec_args, [item, frame, rets, done, source_device, target_device, + target_incarnation, 
rendezvous, rendez_args, + remote_args, exec_args](const Status& status) { item->Unref(); Status s = status; if (s.ok()) { @@ -616,8 +626,8 @@ void FunctionLibraryRuntimeImpl::RunRemote(const Options& opts, Handle handle, return; } s = ProcessFunctionLibraryRuntime::SendTensors( - target_device, source_device, "ret_", *rets, rendez_args, - rendezvous); + target_device, source_device, "ret_", target_incarnation, + *rets, rendez_args, rendezvous); delete remote_args; delete exec_args; done(s); diff --git a/tensorflow/core/common_runtime/process_function_library_runtime.cc b/tensorflow/core/common_runtime/process_function_library_runtime.cc index c39bab2348..26ae6907bc 100644 --- a/tensorflow/core/common_runtime/process_function_library_runtime.cc +++ b/tensorflow/core/common_runtime/process_function_library_runtime.cc @@ -71,13 +71,14 @@ string ProcessFunctionLibraryRuntime::ObtainFunctionTarget( /* static */ Status ProcessFunctionLibraryRuntime::SendTensors( const string& source_device, const string& target_device, - const string& key_prefix, gtl::ArraySlice<Tensor> tensors_to_send, - const Rendezvous::Args& args, Rendezvous* rendezvous) { + const string& key_prefix, int64 src_incarnation, + gtl::ArraySlice<Tensor> tensors_to_send, const Rendezvous::Args& args, + Rendezvous* rendezvous) { std::vector<string> keys; for (int i = 0; i < tensors_to_send.size(); ++i) { string name = strings::StrCat(key_prefix, i); - string key = Rendezvous::CreateKey(source_device, i, target_device, name, - FrameAndIter(0, 0)); + string key = Rendezvous::CreateKey(source_device, src_incarnation, + target_device, name, FrameAndIter(0, 0)); keys.push_back(key); } TF_RETURN_IF_ERROR( @@ -88,14 +89,14 @@ Status ProcessFunctionLibraryRuntime::SendTensors( /* static */ void ProcessFunctionLibraryRuntime::ReceiveTensorsAsync( const string& source_device, const string& target_device, - const string& key_prefix, int64 num_tensors, const Rendezvous::Args& args, - Rendezvous* rendezvous, 
std::vector<Tensor>* received_tensors, - const StatusCallback& done) { + const string& key_prefix, int64 src_incarnation, int64 num_tensors, + const Rendezvous::Args& args, Rendezvous* rendezvous, + std::vector<Tensor>* received_tensors, const StatusCallback& done) { std::vector<string> keys; for (int64 i = 0; i < num_tensors; ++i) { string name = strings::StrCat(key_prefix, i); - string key = Rendezvous::CreateKey(source_device, i, target_device, name, - FrameAndIter(0, 0)); + string key = Rendezvous::CreateKey(source_device, src_incarnation, + target_device, name, FrameAndIter(0, 0)); keys.push_back(key); } RecvOutputsFromRendezvousAsync( @@ -103,6 +104,16 @@ void ProcessFunctionLibraryRuntime::ReceiveTensorsAsync( [done](const Status& status) { done(status); }); } +Status ProcessFunctionLibraryRuntime::GetDeviceIncarnation( + const string& device_name, int64* incarnation) { + FunctionLibraryRuntime* flr = GetFLR(device_name); + if (flr == nullptr) { + return errors::InvalidArgument("Device name: ", device_name, " not found"); + } + *incarnation = flr->device()->attributes().incarnation(); + return Status::OK(); +} + Status ProcessFunctionLibraryRuntime::GetDeviceContext( const string& device_name, DeviceContext** device_context) { *device_context = nullptr; @@ -224,17 +235,25 @@ void ProcessFunctionLibraryRuntime::Run( done(s); return; } + int64 src_incarnation, target_incarnation; + s = GetDeviceIncarnation(source_device, &src_incarnation); + s.Update(GetDeviceIncarnation(target_device, &target_incarnation)); + if (!s.ok()) { + done(s); + return; + } + // Send the args over to the target device. 
- s = SendTensors(source_device, target_device, "arg_", args, rendez_args, - rendezvous); + s = SendTensors(source_device, target_device, "arg_", src_incarnation, args, + rendez_args, rendezvous); if (!s.ok()) { done(s); return; } std::vector<Tensor>* remote_rets = new std::vector<Tensor>; flr->Run(opts, handle, args, remote_rets, - [source_device, target_device, rendezvous, remote_rets, rets, done, - rendez_args](const Status& status) { + [source_device, target_device, target_incarnation, rendezvous, + remote_rets, rets, done, rendez_args](const Status& status) { if (!status.ok()) { delete remote_rets; done(status); @@ -244,8 +263,8 @@ void ProcessFunctionLibraryRuntime::Run( delete remote_rets; // Now receive the return values from the target. ReceiveTensorsAsync(target_device, source_device, "ret_", - num_returns, rendez_args, rendezvous, rets, - done); + target_incarnation, num_returns, rendez_args, + rendezvous, rets, done); }); } else { done(errors::Internal("Could not find device")); diff --git a/tensorflow/core/common_runtime/process_function_library_runtime.h b/tensorflow/core/common_runtime/process_function_library_runtime.h index 2e97bae4b4..7ff1d5c7a7 100644 --- a/tensorflow/core/common_runtime/process_function_library_runtime.h +++ b/tensorflow/core/common_runtime/process_function_library_runtime.h @@ -51,7 +51,7 @@ class ProcessFunctionLibraryRuntime { // Method doesn't block. static Status SendTensors(const string& source_device, const string& target_device, - const string& key_prefix, + const string& key_prefix, int64 src_incarnation, gtl::ArraySlice<Tensor> tensors_to_send, const Rendezvous::Args& args, Rendezvous* rendezvous); @@ -62,18 +62,19 @@ class ProcessFunctionLibraryRuntime { // `source_device`) using `rendezvous`. Uses `key_prefix` to construct the // keys to be retrieved. Method doesn't block and calls `done` when // `num_tensors` are fetched. 
- static void ReceiveTensorsAsync(const string& source_device, - const string& target_device, - const string& key_prefix, int64 num_tensors, - const Rendezvous::Args& args, - Rendezvous* rendezvous, - std::vector<Tensor>* received_tensors, - const StatusCallback& done); + static void ReceiveTensorsAsync( + const string& source_device, const string& target_device, + const string& key_prefix, int64 src_incarnation, int64 num_tensors, + const Rendezvous::Args& args, Rendezvous* rendezvous, + std::vector<Tensor>* received_tensors, const StatusCallback& done); static const char kDefaultFLRDevice[]; // Returns the FunctionLibraryRuntime for the corresponding device_name. FunctionLibraryRuntime* GetFLR(const string& device_name); + // Returns the device incarnation for the given device_name. + Status GetDeviceIncarnation(const string& device_name, int64* incarnation); + // For a given canonicalized key signature of the function instantiated // on device `device_name` and a `local_handle`, creates a handle and returns // that value. Use core/common_runtime/framework/function.h::Canonicalize diff --git a/tensorflow/core/common_runtime/process_function_library_runtime_test.cc b/tensorflow/core/common_runtime/process_function_library_runtime_test.cc index fdbab46f54..50379a52c4 100644 --- a/tensorflow/core/common_runtime/process_function_library_runtime_test.cc +++ b/tensorflow/core/common_runtime/process_function_library_runtime_test.cc @@ -21,6 +21,7 @@ limitations under the License. 
#include "tensorflow/core/common_runtime/rendezvous_mgr.h" #include "tensorflow/core/framework/function_testlib.h" #include "tensorflow/core/framework/tensor_testutil.h" +#include "tensorflow/core/lib/core/status_test_util.h" #include "tensorflow/core/platform/test.h" #include "tensorflow/core/public/session_options.h" #include "tensorflow/core/public/version.h" @@ -120,6 +121,19 @@ TEST_F(ProcessFunctionLibraryRuntimeTest, ObtainFunctionTarget) { EXPECT_EQ("/job:a/replica:0/task:0/cpu:1", target); } +TEST_F(ProcessFunctionLibraryRuntimeTest, GetDeviceIncarnation) { + Init({}); + int64 incarnation; + TF_EXPECT_OK(proc_flr_->GetDeviceIncarnation("/job:a/replica:0/task:0/cpu:1", + &incarnation)); + // Incarnation is a random number other than 0. + EXPECT_NE(incarnation, 0); + Status s = proc_flr_->GetDeviceIncarnation("/job:a/replica:0/task:0/cpu:2", + &incarnation); + EXPECT_EQ(s.code(), error::INVALID_ARGUMENT); + rendezvous_->Unref(); +} + TEST_F(ProcessFunctionLibraryRuntimeTest, SingleCall) { Init({test::function::XTimesTwo()}); FunctionLibraryRuntime::Options opts; |