about | summary | refs | log | tree | commit | diff | homepage
path: root/tensorflow/core
diff options
context:
space:
mode:
author: Rohan Jain <rohanj@google.com>, 2017-09-13 20:40:32 -0700
committer: TensorFlower Gardener <gardener@tensorflow.org>, 2017-09-13 20:45:06 -0700
commit: de0bc082f153e36f9919c2cac8fc1063fe3c9186 (patch)
tree: ee533587844e238296306998a114d5f8ab28e539 /tensorflow/core
parent: ad1069e5900157a3a2a782a3f2a0aa62b0ebab19 (diff)
Making sure that the src_incarnation field on the ParsedKey for the Send and Recv's is set correctly.
PiperOrigin-RevId: 168635306
Diffstat (limited to 'tensorflow/core')
-rw-r--r-- tensorflow/core/common_runtime/function.cc | 28
-rw-r--r-- tensorflow/core/common_runtime/process_function_library_runtime.cc | 49
-rw-r--r-- tensorflow/core/common_runtime/process_function_library_runtime.h | 17
-rw-r--r-- tensorflow/core/common_runtime/process_function_library_runtime_test.cc | 14
4 files changed, 76 insertions, 32 deletions
diff --git a/tensorflow/core/common_runtime/function.cc b/tensorflow/core/common_runtime/function.cc
index 4aeacc6d61..d886a02305 100644
--- a/tensorflow/core/common_runtime/function.cc
+++ b/tensorflow/core/common_runtime/function.cc
@@ -579,6 +579,15 @@ void FunctionLibraryRuntimeImpl::RunRemote(const Options& opts, Handle handle,
done(s);
return;
}
+ int64 src_incarnation, target_incarnation;
+ s = parent_->GetDeviceIncarnation(source_device, &src_incarnation);
+ s.Update(parent_->GetDeviceIncarnation(target_device, &target_incarnation));
+ if (!s.ok()) {
+ delete frame;
+ delete exec_args;
+ done(s);
+ return;
+ }
// The ProcFLR sends the arguments to the function from the source_device to
// the target_device. So here we receive those arguments. Similarly, when the
@@ -586,10 +595,11 @@ void FunctionLibraryRuntimeImpl::RunRemote(const Options& opts, Handle handle,
// to the source_device (caller) so that the ProcFLR can receive them later.
std::vector<Tensor>* remote_args = new std::vector<Tensor>;
ProcessFunctionLibraryRuntime::ReceiveTensorsAsync(
- source_device, target_device, "arg_", args.size(), rendez_args,
- rendezvous, remote_args,
- [frame, remote_args, item, source_device, target_device, rendezvous,
- rendez_args, rets, done, exec_args](const Status& status) {
+ source_device, target_device, "arg_", src_incarnation, args.size(),
+ rendez_args, rendezvous, remote_args,
+ [frame, remote_args, item, source_device, target_device,
+ target_incarnation, rendezvous, rendez_args, rets, done,
+ exec_args](const Status& status) {
Status s = status;
s = frame->SetArgs(*remote_args);
if (!s.ok()) {
@@ -600,9 +610,9 @@ void FunctionLibraryRuntimeImpl::RunRemote(const Options& opts, Handle handle,
return;
}
item->exec->RunAsync(
- *exec_args,
- [item, frame, rets, done, source_device, target_device, rendezvous,
- rendez_args, remote_args, exec_args](const Status& status) {
+ *exec_args, [item, frame, rets, done, source_device, target_device,
+ target_incarnation, rendezvous, rendez_args,
+ remote_args, exec_args](const Status& status) {
item->Unref();
Status s = status;
if (s.ok()) {
@@ -616,8 +626,8 @@ void FunctionLibraryRuntimeImpl::RunRemote(const Options& opts, Handle handle,
return;
}
s = ProcessFunctionLibraryRuntime::SendTensors(
- target_device, source_device, "ret_", *rets, rendez_args,
- rendezvous);
+ target_device, source_device, "ret_", target_incarnation,
+ *rets, rendez_args, rendezvous);
delete remote_args;
delete exec_args;
done(s);
diff --git a/tensorflow/core/common_runtime/process_function_library_runtime.cc b/tensorflow/core/common_runtime/process_function_library_runtime.cc
index c39bab2348..26ae6907bc 100644
--- a/tensorflow/core/common_runtime/process_function_library_runtime.cc
+++ b/tensorflow/core/common_runtime/process_function_library_runtime.cc
@@ -71,13 +71,14 @@ string ProcessFunctionLibraryRuntime::ObtainFunctionTarget(
/* static */
Status ProcessFunctionLibraryRuntime::SendTensors(
const string& source_device, const string& target_device,
- const string& key_prefix, gtl::ArraySlice<Tensor> tensors_to_send,
- const Rendezvous::Args& args, Rendezvous* rendezvous) {
+ const string& key_prefix, int64 src_incarnation,
+ gtl::ArraySlice<Tensor> tensors_to_send, const Rendezvous::Args& args,
+ Rendezvous* rendezvous) {
std::vector<string> keys;
for (int i = 0; i < tensors_to_send.size(); ++i) {
string name = strings::StrCat(key_prefix, i);
- string key = Rendezvous::CreateKey(source_device, i, target_device, name,
- FrameAndIter(0, 0));
+ string key = Rendezvous::CreateKey(source_device, src_incarnation,
+ target_device, name, FrameAndIter(0, 0));
keys.push_back(key);
}
TF_RETURN_IF_ERROR(
@@ -88,14 +89,14 @@ Status ProcessFunctionLibraryRuntime::SendTensors(
/* static */
void ProcessFunctionLibraryRuntime::ReceiveTensorsAsync(
const string& source_device, const string& target_device,
- const string& key_prefix, int64 num_tensors, const Rendezvous::Args& args,
- Rendezvous* rendezvous, std::vector<Tensor>* received_tensors,
- const StatusCallback& done) {
+ const string& key_prefix, int64 src_incarnation, int64 num_tensors,
+ const Rendezvous::Args& args, Rendezvous* rendezvous,
+ std::vector<Tensor>* received_tensors, const StatusCallback& done) {
std::vector<string> keys;
for (int64 i = 0; i < num_tensors; ++i) {
string name = strings::StrCat(key_prefix, i);
- string key = Rendezvous::CreateKey(source_device, i, target_device, name,
- FrameAndIter(0, 0));
+ string key = Rendezvous::CreateKey(source_device, src_incarnation,
+ target_device, name, FrameAndIter(0, 0));
keys.push_back(key);
}
RecvOutputsFromRendezvousAsync(
@@ -103,6 +104,16 @@ void ProcessFunctionLibraryRuntime::ReceiveTensorsAsync(
[done](const Status& status) { done(status); });
}
+Status ProcessFunctionLibraryRuntime::GetDeviceIncarnation(
+ const string& device_name, int64* incarnation) {
+ FunctionLibraryRuntime* flr = GetFLR(device_name);
+ if (flr == nullptr) {
+ return errors::InvalidArgument("Device name: ", device_name, " not found");
+ }
+ *incarnation = flr->device()->attributes().incarnation();
+ return Status::OK();
+}
+
Status ProcessFunctionLibraryRuntime::GetDeviceContext(
const string& device_name, DeviceContext** device_context) {
*device_context = nullptr;
@@ -224,17 +235,25 @@ void ProcessFunctionLibraryRuntime::Run(
done(s);
return;
}
+ int64 src_incarnation, target_incarnation;
+ s = GetDeviceIncarnation(source_device, &src_incarnation);
+ s.Update(GetDeviceIncarnation(target_device, &target_incarnation));
+ if (!s.ok()) {
+ done(s);
+ return;
+ }
+
// Send the args over to the target device.
- s = SendTensors(source_device, target_device, "arg_", args, rendez_args,
- rendezvous);
+ s = SendTensors(source_device, target_device, "arg_", src_incarnation, args,
+ rendez_args, rendezvous);
if (!s.ok()) {
done(s);
return;
}
std::vector<Tensor>* remote_rets = new std::vector<Tensor>;
flr->Run(opts, handle, args, remote_rets,
- [source_device, target_device, rendezvous, remote_rets, rets, done,
- rendez_args](const Status& status) {
+ [source_device, target_device, target_incarnation, rendezvous,
+ remote_rets, rets, done, rendez_args](const Status& status) {
if (!status.ok()) {
delete remote_rets;
done(status);
@@ -244,8 +263,8 @@ void ProcessFunctionLibraryRuntime::Run(
delete remote_rets;
// Now receive the return values from the target.
ReceiveTensorsAsync(target_device, source_device, "ret_",
- num_returns, rendez_args, rendezvous, rets,
- done);
+ target_incarnation, num_returns, rendez_args,
+ rendezvous, rets, done);
});
} else {
done(errors::Internal("Could not find device"));
diff --git a/tensorflow/core/common_runtime/process_function_library_runtime.h b/tensorflow/core/common_runtime/process_function_library_runtime.h
index 2e97bae4b4..7ff1d5c7a7 100644
--- a/tensorflow/core/common_runtime/process_function_library_runtime.h
+++ b/tensorflow/core/common_runtime/process_function_library_runtime.h
@@ -51,7 +51,7 @@ class ProcessFunctionLibraryRuntime {
// Method doesn't block.
static Status SendTensors(const string& source_device,
const string& target_device,
- const string& key_prefix,
+ const string& key_prefix, int64 src_incarnation,
gtl::ArraySlice<Tensor> tensors_to_send,
const Rendezvous::Args& args,
Rendezvous* rendezvous);
@@ -62,18 +62,19 @@ class ProcessFunctionLibraryRuntime {
// `source_device`) using `rendezvous`. Uses `key_prefix` to construct the
// keys to be retrieved. Method doesn't block and calls `done` when
// `num_tensors` are fetched.
- static void ReceiveTensorsAsync(const string& source_device,
- const string& target_device,
- const string& key_prefix, int64 num_tensors,
- const Rendezvous::Args& args,
- Rendezvous* rendezvous,
- std::vector<Tensor>* received_tensors,
- const StatusCallback& done);
+ static void ReceiveTensorsAsync(
+ const string& source_device, const string& target_device,
+ const string& key_prefix, int64 src_incarnation, int64 num_tensors,
+ const Rendezvous::Args& args, Rendezvous* rendezvous,
+ std::vector<Tensor>* received_tensors, const StatusCallback& done);
static const char kDefaultFLRDevice[];
// Returns the FunctionLibraryRuntime for the corresponding device_name.
FunctionLibraryRuntime* GetFLR(const string& device_name);
+ // Returns the device incarnation for the given device_name.
+ Status GetDeviceIncarnation(const string& device_name, int64* incarnation);
+
// For a given canonicalized key signature of the function instantiated
// on device `device_name` and a `local_handle`, creates a handle and returns
// that value. Use core/common_runtime/framework/function.h::Canonicalize
diff --git a/tensorflow/core/common_runtime/process_function_library_runtime_test.cc b/tensorflow/core/common_runtime/process_function_library_runtime_test.cc
index fdbab46f54..50379a52c4 100644
--- a/tensorflow/core/common_runtime/process_function_library_runtime_test.cc
+++ b/tensorflow/core/common_runtime/process_function_library_runtime_test.cc
@@ -21,6 +21,7 @@ limitations under the License.
#include "tensorflow/core/common_runtime/rendezvous_mgr.h"
#include "tensorflow/core/framework/function_testlib.h"
#include "tensorflow/core/framework/tensor_testutil.h"
+#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/platform/test.h"
#include "tensorflow/core/public/session_options.h"
#include "tensorflow/core/public/version.h"
@@ -120,6 +121,19 @@ TEST_F(ProcessFunctionLibraryRuntimeTest, ObtainFunctionTarget) {
EXPECT_EQ("/job:a/replica:0/task:0/cpu:1", target);
}
+TEST_F(ProcessFunctionLibraryRuntimeTest, GetDeviceIncarnation) {
+ Init({});
+ int64 incarnation;
+ TF_EXPECT_OK(proc_flr_->GetDeviceIncarnation("/job:a/replica:0/task:0/cpu:1",
+ &incarnation));
+ // Incarnation is a random number other than 0.
+ EXPECT_NE(incarnation, 0);
+ Status s = proc_flr_->GetDeviceIncarnation("/job:a/replica:0/task:0/cpu:2",
+ &incarnation);
+ EXPECT_EQ(s.code(), error::INVALID_ARGUMENT);
+ rendezvous_->Unref();
+}
+
TEST_F(ProcessFunctionLibraryRuntimeTest, SingleCall) {
Init({test::function::XTimesTwo()});
FunctionLibraryRuntime::Options opts;