diff options
author | 2017-11-01 16:15:20 -0700 | |
---|---|---|
committer | 2017-11-01 16:20:28 -0700 | |
commit | 117bcd9cb5f3e55ce1fcc09a0bb4963c32bad8ce (patch) | |
tree | 3d9a1d16b4bb78057e1a613ab01858d0d0fbcf22 /tensorflow/core | |
parent | 70698a168669e0335872ce9248a6c496328d7871 (diff) |
Adding support for local device names for ProcessFLR. Now one can specify a remote target as /device:CPU:0, /device:GPU:0, etc.
PiperOrigin-RevId: 174252575
Diffstat (limited to 'tensorflow/core')
5 files changed, 45 insertions, 21 deletions
diff --git a/tensorflow/core/common_runtime/function.cc b/tensorflow/core/common_runtime/function.cc index 10356fc789..23d0f331c5 100644 --- a/tensorflow/core/common_runtime/function.cc +++ b/tensorflow/core/common_runtime/function.cc @@ -411,7 +411,11 @@ bool FunctionLibraryRuntimeImpl::IsLocalTarget(const AttrSlice& attrs) { if (device_ == nullptr) return true; string target = ProcessFunctionLibraryRuntime::ObtainFunctionTarget(attrs); if (target.empty()) return true; - return target == device_->name(); + Device* target_device; + if (!device_mgr_->LookupDevice(target, &target_device).ok()) { + return false; + } + return target_device == device_; } AttrValueMap FunctionLibraryRuntimeImpl::FixAttrs(const AttrSlice& attrs) { diff --git a/tensorflow/core/common_runtime/function_test.cc b/tensorflow/core/common_runtime/function_test.cc index b77a8f50c4..d183bf7c97 100644 --- a/tensorflow/core/common_runtime/function_test.cc +++ b/tensorflow/core/common_runtime/function_test.cc @@ -939,9 +939,8 @@ TEST_F(FunctionLibraryRuntimeTest, Gradient_AddSum) { TEST_F(FunctionLibraryRuntimeTest, CrossDevice) { Init({test::function::FindDevice()}); FunctionLibraryRuntime::Handle handle; - TF_CHECK_OK(Instantiate( - flr0_, "FindDevice", - {{"_target", "/job:localhost/replica:0/task:0/cpu:1"}}, &handle)); + TF_CHECK_OK(Instantiate(flr0_, "FindDevice", {{"_target", "/device:CPU:1"}}, + &handle)); Tensor y; FunctionLibraryRuntime::Options opts; diff --git a/tensorflow/core/common_runtime/process_function_library_runtime.cc b/tensorflow/core/common_runtime/process_function_library_runtime.cc index c4114ff873..142ff2339b 100644 --- a/tensorflow/core/common_runtime/process_function_library_runtime.cc +++ b/tensorflow/core/common_runtime/process_function_library_runtime.cc @@ -30,15 +30,15 @@ ProcessFunctionLibraryRuntime::ProcessFunctionLibraryRuntime( const FunctionLibraryDefinition* lib_def, const OptimizerOptions& optimizer_options, DistributedFunctionLibraryRuntime* parent) - : 
lib_def_(lib_def), parent_(parent) { + : device_mgr_(device_mgr), lib_def_(lib_def), parent_(parent) { if (device_mgr == nullptr) { - flr_map_[kDefaultFLRDevice] = + flr_map_[nullptr] = NewFunctionLibraryRuntime(nullptr, env, nullptr, graph_def_version, lib_def, optimizer_options, this); return; } for (Device* d : device_mgr->ListDevices()) { - flr_map_[d->name()] = + flr_map_[d] = NewFunctionLibraryRuntime(device_mgr, env, d, graph_def_version, lib_def, optimizer_options, this); } @@ -50,15 +50,15 @@ ProcessFunctionLibraryRuntime::ProcessFunctionLibraryRuntime( const OptimizerOptions& optimizer_options, CustomKernelCreator custom_kernel_creator, DistributedFunctionLibraryRuntime* parent) - : lib_def_(lib_def), parent_(parent) { + : device_mgr_(device_mgr), lib_def_(lib_def), parent_(parent) { if (device_mgr == nullptr) { - flr_map_[kDefaultFLRDevice] = NewFunctionLibraryRuntime( + flr_map_[nullptr] = NewFunctionLibraryRuntime( nullptr, env, nullptr, graph_def_version, lib_def, optimizer_options, std::move(custom_kernel_creator), this); return; } for (Device* d : device_mgr->ListDevices()) { - flr_map_[d->name()] = NewFunctionLibraryRuntime( + flr_map_[d] = NewFunctionLibraryRuntime( device_mgr, env, d, graph_def_version, lib_def, optimizer_options, custom_kernel_creator, this); } @@ -163,17 +163,19 @@ Status ProcessFunctionLibraryRuntime::GetDeviceContext( FunctionLibraryRuntime* ProcessFunctionLibraryRuntime::GetFLR( const string& device_name) { - string clean_device_name; + Device* device = nullptr; if (device_name != kDefaultFLRDevice) { - clean_device_name = DeviceNameUtils::CanonicalizeDeviceName(device_name); - } else { - clean_device_name = device_name; + if (!device_mgr_->LookupDevice(device_name, &device).ok()) { + LOG(ERROR) << "Could not find device: " << device_name; + return nullptr; + } } - if (flr_map_.find(clean_device_name) == flr_map_.end()) { + const auto& iter = flr_map_.find(device); + if (iter == flr_map_.end()) { LOG(ERROR) << "Could not 
find device: " << device_name; return nullptr; } - return flr_map_[clean_device_name].get(); + return iter->second.get(); } FunctionLibraryRuntime::Handle ProcessFunctionLibraryRuntime::AddHandle( diff --git a/tensorflow/core/common_runtime/process_function_library_runtime.h b/tensorflow/core/common_runtime/process_function_library_runtime.h index 85717739d0..a267bc3601 100644 --- a/tensorflow/core/common_runtime/process_function_library_runtime.h +++ b/tensorflow/core/common_runtime/process_function_library_runtime.h @@ -153,12 +153,13 @@ class ProcessFunctionLibraryRuntime { : target_device(target_device), local_handle(local_handle) {} }; + const DeviceMgr* const device_mgr_; const FunctionLibraryDefinition* lib_def_; // Holds all the function invocations here. std::unordered_map<string, FunctionLibraryRuntime::Handle> table_ GUARDED_BY(mu_); std::vector<FunctionData> function_data_ GUARDED_BY(mu_); - std::unordered_map<string, std::unique_ptr<FunctionLibraryRuntime>> flr_map_; + std::unordered_map<Device*, std::unique_ptr<FunctionLibraryRuntime>> flr_map_; DistributedFunctionLibraryRuntime* const parent_; }; diff --git a/tensorflow/core/common_runtime/process_function_library_runtime_test.cc b/tensorflow/core/common_runtime/process_function_library_runtime_test.cc index cb416603be..6bc8f980c7 100644 --- a/tensorflow/core/common_runtime/process_function_library_runtime_test.cc +++ b/tensorflow/core/common_runtime/process_function_library_runtime_test.cc @@ -92,12 +92,32 @@ class ProcessFunctionLibraryRuntimeTest : public ::testing::Test { IntraProcessRendezvous* rendezvous_; }; +TEST_F(ProcessFunctionLibraryRuntimeTest, GetFLRNull) { + FunctionDefLibrary proto; + std::unique_ptr<FunctionLibraryDefinition> lib_def( + new FunctionLibraryDefinition(OpRegistry::Global(), proto)); + OptimizerOptions opts; + std::unique_ptr<ProcessFunctionLibraryRuntime> proc_flr( + new ProcessFunctionLibraryRuntime( + nullptr /* device_mgr */, Env::Default(), TF_GRAPH_DEF_VERSION, + 
lib_def.get(), opts, nullptr /* cluster_flr */)); + FunctionLibraryRuntime* flr = + proc_flr->GetFLR(ProcessFunctionLibraryRuntime::kDefaultFLRDevice); + EXPECT_NE(flr, nullptr); +} + TEST_F(ProcessFunctionLibraryRuntimeTest, Basic) { Init({}); FunctionLibraryRuntime* flr = proc_flr_->GetFLR("/job:a/replica:0/task:0/cpu:0"); EXPECT_NE(flr, nullptr); EXPECT_EQ(flr->device(), devices_[0]); + flr = proc_flr_->GetFLR("/job:a/replica:0/task:0/device:CPU:0"); + EXPECT_NE(flr, nullptr); + EXPECT_EQ(flr->device(), devices_[0]); + flr = proc_flr_->GetFLR("/device:CPU:0"); + EXPECT_NE(flr, nullptr); + EXPECT_EQ(flr->device(), devices_[0]); flr = proc_flr_->GetFLR("/job:a/replica:0/task:0/cpu:1"); EXPECT_NE(flr, nullptr); EXPECT_EQ(flr->device(), devices_[1]); @@ -213,13 +233,11 @@ TEST_F(ProcessFunctionLibraryRuntimeTest, MultipleCallsDiffDeviceFindDevice) { opts.rendezvous = rendezvous_; opts.remote_execution = true; Tensor y; - TF_CHECK_OK(Run("FindDevice", opts, - {{"_target", "/job:a/replica:0/task:0/cpu:0"}}, {}, {&y})); + TF_CHECK_OK(Run("FindDevice", opts, {{"_target", "/cpu:0"}}, {}, {&y})); test::ExpectTensorEqual<string>( y, test::AsTensor<string>({"/job:a/replica:0/task:0/device:CPU:0"}, TensorShape({}))); - TF_CHECK_OK(Run("FindDevice", opts, - {{"_target", "/job:a/replica:0/task:0/cpu:1"}}, {}, {&y})); + TF_CHECK_OK(Run("FindDevice", opts, {{"_target", "/cpu:1"}}, {}, {&y})); test::ExpectTensorEqual<string>( y, test::AsTensor<string>({"/job:a/replica:0/task:0/device:CPU:1"}, TensorShape({}))); |