aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core
diff options
context:
space:
mode:
authorGravatar Rohan Jain <rohanj@google.com>2017-11-01 16:15:20 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-11-01 16:20:28 -0700
commit117bcd9cb5f3e55ce1fcc09a0bb4963c32bad8ce (patch)
tree3d9a1d16b4bb78057e1a613ab01858d0d0fbcf22 /tensorflow/core
parent70698a168669e0335872ce9248a6c496328d7871 (diff)
Adding support for local device names for ProcessFLR. Now one can specify a remote target as /device:CPU:0 or /device:GPU:0 etc.
PiperOrigin-RevId: 174252575
Diffstat (limited to 'tensorflow/core')
-rw-r--r--tensorflow/core/common_runtime/function.cc6
-rw-r--r--tensorflow/core/common_runtime/function_test.cc5
-rw-r--r--tensorflow/core/common_runtime/process_function_library_runtime.cc26
-rw-r--r--tensorflow/core/common_runtime/process_function_library_runtime.h3
-rw-r--r--tensorflow/core/common_runtime/process_function_library_runtime_test.cc26
5 files changed, 45 insertions, 21 deletions
diff --git a/tensorflow/core/common_runtime/function.cc b/tensorflow/core/common_runtime/function.cc
index 10356fc789..23d0f331c5 100644
--- a/tensorflow/core/common_runtime/function.cc
+++ b/tensorflow/core/common_runtime/function.cc
@@ -411,7 +411,11 @@ bool FunctionLibraryRuntimeImpl::IsLocalTarget(const AttrSlice& attrs) {
if (device_ == nullptr) return true;
string target = ProcessFunctionLibraryRuntime::ObtainFunctionTarget(attrs);
if (target.empty()) return true;
- return target == device_->name();
+ Device* target_device;
+ if (!device_mgr_->LookupDevice(target, &target_device).ok()) {
+ return false;
+ }
+ return target_device == device_;
}
AttrValueMap FunctionLibraryRuntimeImpl::FixAttrs(const AttrSlice& attrs) {
diff --git a/tensorflow/core/common_runtime/function_test.cc b/tensorflow/core/common_runtime/function_test.cc
index b77a8f50c4..d183bf7c97 100644
--- a/tensorflow/core/common_runtime/function_test.cc
+++ b/tensorflow/core/common_runtime/function_test.cc
@@ -939,9 +939,8 @@ TEST_F(FunctionLibraryRuntimeTest, Gradient_AddSum) {
TEST_F(FunctionLibraryRuntimeTest, CrossDevice) {
Init({test::function::FindDevice()});
FunctionLibraryRuntime::Handle handle;
- TF_CHECK_OK(Instantiate(
- flr0_, "FindDevice",
- {{"_target", "/job:localhost/replica:0/task:0/cpu:1"}}, &handle));
+ TF_CHECK_OK(Instantiate(flr0_, "FindDevice", {{"_target", "/device:CPU:1"}},
+ &handle));
Tensor y;
FunctionLibraryRuntime::Options opts;
diff --git a/tensorflow/core/common_runtime/process_function_library_runtime.cc b/tensorflow/core/common_runtime/process_function_library_runtime.cc
index c4114ff873..142ff2339b 100644
--- a/tensorflow/core/common_runtime/process_function_library_runtime.cc
+++ b/tensorflow/core/common_runtime/process_function_library_runtime.cc
@@ -30,15 +30,15 @@ ProcessFunctionLibraryRuntime::ProcessFunctionLibraryRuntime(
const FunctionLibraryDefinition* lib_def,
const OptimizerOptions& optimizer_options,
DistributedFunctionLibraryRuntime* parent)
- : lib_def_(lib_def), parent_(parent) {
+ : device_mgr_(device_mgr), lib_def_(lib_def), parent_(parent) {
if (device_mgr == nullptr) {
- flr_map_[kDefaultFLRDevice] =
+ flr_map_[nullptr] =
NewFunctionLibraryRuntime(nullptr, env, nullptr, graph_def_version,
lib_def, optimizer_options, this);
return;
}
for (Device* d : device_mgr->ListDevices()) {
- flr_map_[d->name()] =
+ flr_map_[d] =
NewFunctionLibraryRuntime(device_mgr, env, d, graph_def_version,
lib_def, optimizer_options, this);
}
@@ -50,15 +50,15 @@ ProcessFunctionLibraryRuntime::ProcessFunctionLibraryRuntime(
const OptimizerOptions& optimizer_options,
CustomKernelCreator custom_kernel_creator,
DistributedFunctionLibraryRuntime* parent)
- : lib_def_(lib_def), parent_(parent) {
+ : device_mgr_(device_mgr), lib_def_(lib_def), parent_(parent) {
if (device_mgr == nullptr) {
- flr_map_[kDefaultFLRDevice] = NewFunctionLibraryRuntime(
+ flr_map_[nullptr] = NewFunctionLibraryRuntime(
nullptr, env, nullptr, graph_def_version, lib_def, optimizer_options,
std::move(custom_kernel_creator), this);
return;
}
for (Device* d : device_mgr->ListDevices()) {
- flr_map_[d->name()] = NewFunctionLibraryRuntime(
+ flr_map_[d] = NewFunctionLibraryRuntime(
device_mgr, env, d, graph_def_version, lib_def, optimizer_options,
custom_kernel_creator, this);
}
@@ -163,17 +163,19 @@ Status ProcessFunctionLibraryRuntime::GetDeviceContext(
FunctionLibraryRuntime* ProcessFunctionLibraryRuntime::GetFLR(
const string& device_name) {
- string clean_device_name;
+ Device* device = nullptr;
if (device_name != kDefaultFLRDevice) {
- clean_device_name = DeviceNameUtils::CanonicalizeDeviceName(device_name);
- } else {
- clean_device_name = device_name;
+ if (!device_mgr_->LookupDevice(device_name, &device).ok()) {
+ LOG(ERROR) << "Could not find device: " << device_name;
+ return nullptr;
+ }
}
- if (flr_map_.find(clean_device_name) == flr_map_.end()) {
+ const auto& iter = flr_map_.find(device);
+ if (iter == flr_map_.end()) {
LOG(ERROR) << "Could not find device: " << device_name;
return nullptr;
}
- return flr_map_[clean_device_name].get();
+ return iter->second.get();
}
FunctionLibraryRuntime::Handle ProcessFunctionLibraryRuntime::AddHandle(
diff --git a/tensorflow/core/common_runtime/process_function_library_runtime.h b/tensorflow/core/common_runtime/process_function_library_runtime.h
index 85717739d0..a267bc3601 100644
--- a/tensorflow/core/common_runtime/process_function_library_runtime.h
+++ b/tensorflow/core/common_runtime/process_function_library_runtime.h
@@ -153,12 +153,13 @@ class ProcessFunctionLibraryRuntime {
: target_device(target_device), local_handle(local_handle) {}
};
+ const DeviceMgr* const device_mgr_;
const FunctionLibraryDefinition* lib_def_;
// Holds all the function invocations here.
std::unordered_map<string, FunctionLibraryRuntime::Handle> table_
GUARDED_BY(mu_);
std::vector<FunctionData> function_data_ GUARDED_BY(mu_);
- std::unordered_map<string, std::unique_ptr<FunctionLibraryRuntime>> flr_map_;
+ std::unordered_map<Device*, std::unique_ptr<FunctionLibraryRuntime>> flr_map_;
DistributedFunctionLibraryRuntime* const parent_;
};
diff --git a/tensorflow/core/common_runtime/process_function_library_runtime_test.cc b/tensorflow/core/common_runtime/process_function_library_runtime_test.cc
index cb416603be..6bc8f980c7 100644
--- a/tensorflow/core/common_runtime/process_function_library_runtime_test.cc
+++ b/tensorflow/core/common_runtime/process_function_library_runtime_test.cc
@@ -92,12 +92,32 @@ class ProcessFunctionLibraryRuntimeTest : public ::testing::Test {
IntraProcessRendezvous* rendezvous_;
};
+TEST_F(ProcessFunctionLibraryRuntimeTest, GetFLRNull) {
+ FunctionDefLibrary proto;
+ std::unique_ptr<FunctionLibraryDefinition> lib_def(
+ new FunctionLibraryDefinition(OpRegistry::Global(), proto));
+ OptimizerOptions opts;
+ std::unique_ptr<ProcessFunctionLibraryRuntime> proc_flr(
+ new ProcessFunctionLibraryRuntime(
+ nullptr /* device_mgr */, Env::Default(), TF_GRAPH_DEF_VERSION,
+ lib_def.get(), opts, nullptr /* cluster_flr */));
+ FunctionLibraryRuntime* flr =
+ proc_flr->GetFLR(ProcessFunctionLibraryRuntime::kDefaultFLRDevice);
+ EXPECT_NE(flr, nullptr);
+}
+
TEST_F(ProcessFunctionLibraryRuntimeTest, Basic) {
Init({});
FunctionLibraryRuntime* flr =
proc_flr_->GetFLR("/job:a/replica:0/task:0/cpu:0");
EXPECT_NE(flr, nullptr);
EXPECT_EQ(flr->device(), devices_[0]);
+ flr = proc_flr_->GetFLR("/job:a/replica:0/task:0/device:CPU:0");
+ EXPECT_NE(flr, nullptr);
+ EXPECT_EQ(flr->device(), devices_[0]);
+ flr = proc_flr_->GetFLR("/device:CPU:0");
+ EXPECT_NE(flr, nullptr);
+ EXPECT_EQ(flr->device(), devices_[0]);
flr = proc_flr_->GetFLR("/job:a/replica:0/task:0/cpu:1");
EXPECT_NE(flr, nullptr);
EXPECT_EQ(flr->device(), devices_[1]);
@@ -213,13 +233,11 @@ TEST_F(ProcessFunctionLibraryRuntimeTest, MultipleCallsDiffDeviceFindDevice) {
opts.rendezvous = rendezvous_;
opts.remote_execution = true;
Tensor y;
- TF_CHECK_OK(Run("FindDevice", opts,
- {{"_target", "/job:a/replica:0/task:0/cpu:0"}}, {}, {&y}));
+ TF_CHECK_OK(Run("FindDevice", opts, {{"_target", "/cpu:0"}}, {}, {&y}));
test::ExpectTensorEqual<string>(
y, test::AsTensor<string>({"/job:a/replica:0/task:0/device:CPU:0"},
TensorShape({})));
- TF_CHECK_OK(Run("FindDevice", opts,
- {{"_target", "/job:a/replica:0/task:0/cpu:1"}}, {}, {&y}));
+ TF_CHECK_OK(Run("FindDevice", opts, {{"_target", "/cpu:1"}}, {}, {&y}));
test::ExpectTensorEqual<string>(
y, test::AsTensor<string>({"/job:a/replica:0/task:0/device:CPU:1"},
TensorShape({})));