diff options
author | 2018-01-04 14:48:09 -0800 | |
---|---|---|
committer | 2018-01-04 14:51:51 -0800 | |
commit | 19bbc31eee8b81bce6eb08b7ada539943fac6014 (patch) | |
tree | 079f8dcc4b87783f3c356011e5708e11a33814d4 /tensorflow | |
parent | b639608a6da140f720636582022a575d7c8a7650 (diff) |
Add `FunctionLibraryRuntime::InstantiateOptions` struct.
This new struct allows optional arguments to be passed to the
`FunctionLibraryRuntime::Instantiate()` API. The new struct is now
used to configure the target device for a function instantiation
(instead of an attr).
PiperOrigin-RevId: 180848930
Diffstat (limited to 'tensorflow')
14 files changed, 150 insertions, 171 deletions
diff --git a/tensorflow/contrib/data/kernels/prefetching_kernels.cc b/tensorflow/contrib/data/kernels/prefetching_kernels.cc index c9a3537c70..742835c964 100644 --- a/tensorflow/contrib/data/kernels/prefetching_kernels.cc +++ b/tensorflow/contrib/data/kernels/prefetching_kernels.cc @@ -83,11 +83,8 @@ class FunctionBufferingResource : public ResourceBase { return Status::OK(); } AttrValueMap attr_values = func_.attr(); - AttrValue v; - v.set_s(target_device_); - AddAttr("_target", v, &attr_values); - - return lib_->Instantiate(func_.name(), AttrSlice(&attr_values), &handle_); + return lib_->Instantiate(func_.name(), AttrSlice(&attr_values), + {target_device_}, &handle_); } // Returns true if we've got to the end of the sequence and exhausted the diff --git a/tensorflow/core/common_runtime/function.cc b/tensorflow/core/common_runtime/function.cc index 51d7f98f72..286266a485 100644 --- a/tensorflow/core/common_runtime/function.cc +++ b/tensorflow/core/common_runtime/function.cc @@ -152,6 +152,7 @@ class FunctionLibraryRuntimeImpl : public FunctionLibraryRuntime { ~FunctionLibraryRuntimeImpl() override; Status Instantiate(const string& function_name, AttrSlice attrs, + const InstantiateOptions& options, Handle* handle) override; Status ReleaseHandle(Handle handle) override; @@ -223,7 +224,7 @@ class FunctionLibraryRuntimeImpl : public FunctionLibraryRuntime { Status GetOrCreateItem(Handle handle, Item** item); Status InstantiateSymbolicGradient(const NameAttrList& func, FunctionBody** g_body); - bool IsLocalTarget(const AttrSlice& attrs); + bool IsLocalTarget(const InstantiateOptions& options); AttrValueMap FixAttrs(const AttrSlice& attrs); void RunRemote(const Options& opts, Handle handle, gtl::ArraySlice<Tensor> args, std::vector<Tensor>* rets, @@ -352,7 +353,8 @@ Status FunctionLibraryRuntimeImpl::CreateKernel(const NodeDef& ndef, // Try to instantiate this function for the func/attr. Maybe it's // cached already. Handle handle; - TF_RETURN_IF_ERROR(Instantiate(ndef.op(), AttrSlice(&ndef.attr()), &handle)); + TF_RETURN_IF_ERROR( + Instantiate(ndef.op(), AttrSlice(&ndef.attr()), {}, &handle)); const FunctionBody* fbody = GetFunctionBody(handle); CHECK_NOTNULL(fbody); @@ -411,7 +413,7 @@ Status FunctionLibraryRuntimeImpl::InstantiateSymbolicGradient( // f is a user-defined function. Handle f_handle; TF_RETURN_IF_ERROR( - Instantiate(func.name(), AttrSlice(&func.attr()), &f_handle)); + Instantiate(func.name(), AttrSlice(&func.attr()), {}, &f_handle)); const FunctionBody* f_body = GetFunctionBody(f_handle); CHECK_NOTNULL(f_body); *g_body = SymbolicGradient(*f_body); @@ -419,42 +421,25 @@ Status FunctionLibraryRuntimeImpl::InstantiateSymbolicGradient( return Status::OK(); } -bool FunctionLibraryRuntimeImpl::IsLocalTarget(const AttrSlice& attrs) { +bool FunctionLibraryRuntimeImpl::IsLocalTarget( + const InstantiateOptions& options) { if (device_ == nullptr) return true; - string target = ProcessFunctionLibraryRuntime::ObtainFunctionTarget(attrs); - if (target.empty()) return true; + if (options.target.empty()) return true; Device* target_device; - if (!device_mgr_->LookupDevice(target, &target_device).ok()) { + if (!device_mgr_->LookupDevice(options.target, &target_device).ok()) { return false; } return target_device == device_; } -AttrValueMap FunctionLibraryRuntimeImpl::FixAttrs(const AttrSlice& attrs) { - AttrValueMap value_map; - for (auto it : attrs) { - value_map[it.first] = it.second; - } - if (attrs.Find("_target") != nullptr) { - return value_map; - } - AttrValue v; - v.set_s(device_name_); - AddAttr("_target", v, &value_map); - return value_map; -} - -Status FunctionLibraryRuntimeImpl::Instantiate(const string& function_name, - AttrSlice attrs, - Handle* handle) { - AttrValueMap value_map = FixAttrs(attrs); - AttrSlice new_attrs(&value_map); - - if (!IsLocalTarget(new_attrs)) { - return parent_->Instantiate(function_name, new_attrs, handle); +Status FunctionLibraryRuntimeImpl::Instantiate( + const string& function_name, AttrSlice attrs, + const InstantiateOptions& options, Handle* handle) { + if (!IsLocalTarget(options)) { + return parent_->Instantiate(function_name, attrs, options, handle); } - const string key = Canonicalize(function_name, new_attrs); + const string key = Canonicalize(function_name, attrs, options); *handle = parent_->GetHandle(key); if (*handle != kInvalidHandle) { return Status::OK(); @@ -463,7 +448,7 @@ Status FunctionLibraryRuntimeImpl::Instantiate(const string& function_name, Status s; FunctionBody* fbody = nullptr; if (function_name == kGradientOp) { - const AttrValue* f = new_attrs.Find(kFuncAttr); + const AttrValue* f = attrs.Find(kFuncAttr); if (f == nullptr) { return errors::InvalidArgument("SymbolicGradient is missing attr: f"); } @@ -473,7 +458,7 @@ Status FunctionLibraryRuntimeImpl::Instantiate(const string& function_name, } const string grad = lib_def_->FindGradient(func.name()); if (!grad.empty()) { - return Instantiate(grad, AttrSlice(&func.attr()), handle); + return Instantiate(grad, AttrSlice(&func.attr()), options, handle); } TF_RETURN_IF_ERROR(InstantiateSymbolicGradient(func, &fbody)); } else { @@ -481,7 +466,7 @@ Status FunctionLibraryRuntimeImpl::Instantiate(const string& function_name, if (fdef == nullptr) { return errors::NotFound("Function ", function_name, " is not defined."); } - TF_RETURN_IF_ERROR(FunctionDefToBody(*fdef, new_attrs, &fbody)); + TF_RETURN_IF_ERROR(FunctionDefToBody(*fdef, attrs, &fbody)); } { diff --git a/tensorflow/core/common_runtime/function_test.cc b/tensorflow/core/common_runtime/function_test.cc index d4181ff48c..2dacacea7b 100644 --- a/tensorflow/core/common_runtime/function_test.cc +++ b/tensorflow/core/common_runtime/function_test.cc @@ -191,11 +191,14 @@ class FunctionLibraryRuntimeTest : public ::testing::Test { Status Instantiate(FunctionLibraryRuntime* flr, const string& name, test::function::Attrs attrs, FunctionLibraryRuntime::Handle* handle) { - Status status = flr->Instantiate(name, attrs, handle); - if (!status.ok()) { - return status; - } - return Status::OK(); + return flr->Instantiate(name, attrs, handle); + } + + Status Instantiate(FunctionLibraryRuntime* flr, const string& name, + test::function::Attrs attrs, + const FunctionLibraryRuntime::InstantiateOptions& options, + FunctionLibraryRuntime::Handle* handle) { + return flr->Instantiate(name, attrs, options, handle); } Status InstantiateAndRun(FunctionLibraryRuntime* flr, const string& name, @@ -1088,8 +1091,7 @@ TEST_F(FunctionLibraryRuntimeTest, Gradient_AddSum) { TEST_F(FunctionLibraryRuntimeTest, CrossDevice) { Init({test::function::FindDevice()}); FunctionLibraryRuntime::Handle handle; - TF_CHECK_OK(Instantiate(flr0_, "FindDevice", {{"_target", "/device:CPU:1"}}, - &handle)); + TF_CHECK_OK(Instantiate(flr0_, "FindDevice", {}, {"/device:CPU:1"}, &handle)); Tensor y; FunctionLibraryRuntime::Options opts; diff --git a/tensorflow/core/common_runtime/process_function_library_runtime.cc b/tensorflow/core/common_runtime/process_function_library_runtime.cc index 53a14121d4..12947e284a 100644 --- a/tensorflow/core/common_runtime/process_function_library_runtime.cc +++ b/tensorflow/core/common_runtime/process_function_library_runtime.cc @@ -88,16 +88,6 @@ ProcessFunctionLibraryRuntime::ProcessFunctionLibraryRuntime( std::move(custom_kernel_creator), nullptr /* cluster_flr */) {} /* static */ -string ProcessFunctionLibraryRuntime::ObtainFunctionTarget( - const AttrSlice& attrs) { - const AttrValue* value; - if (!attrs.Find("_target", &value).ok()) { - return ""; - } - return DeviceNameUtils::CanonicalizeDeviceName(value->s()); -} - -/* static */ Status ProcessFunctionLibraryRuntime::SendTensors( const string& source_device, const string& target_device, const string& key_prefix, int64 src_incarnation, @@ -240,22 +230,23 @@ string ProcessFunctionLibraryRuntime::GetDeviceName( Status ProcessFunctionLibraryRuntime::Instantiate( const string& function_name, AttrSlice attrs, + const FunctionLibraryRuntime::InstantiateOptions& options, FunctionLibraryRuntime::Handle* handle) { *handle = kInvalidHandle; - string target = ObtainFunctionTarget(attrs); - FunctionLibraryRuntime* flr = GetFLR(target); + FunctionLibraryRuntime* flr = GetFLR(options.target); if (flr != nullptr) { - return flr->Instantiate(function_name, attrs, handle); + return flr->Instantiate(function_name, attrs, options, handle); } if (parent_ == nullptr) { return errors::Internal( - "Currently don't support instantiating functions on device: ", target); + "Currently don't support instantiating functions on device: ", + options.target); } FunctionLibraryRuntime::Handle cluster_handle; - TF_RETURN_IF_ERROR( - parent_->Instantiate(function_name, *lib_def_, attrs, &cluster_handle)); + TF_RETURN_IF_ERROR(parent_->Instantiate(function_name, *lib_def_, attrs, + options, &cluster_handle)); string function_key = Canonicalize(function_name, attrs); - *handle = AddHandle(function_key, target, cluster_handle); + *handle = AddHandle(function_key, options.target, cluster_handle); return Status::OK(); } diff --git a/tensorflow/core/common_runtime/process_function_library_runtime.h b/tensorflow/core/common_runtime/process_function_library_runtime.h index 3aa7b87286..38003b7726 100644 --- a/tensorflow/core/common_runtime/process_function_library_runtime.h +++ b/tensorflow/core/common_runtime/process_function_library_runtime.h @@ -53,11 +53,6 @@ class ProcessFunctionLibraryRuntime { const OptimizerOptions& optimizer_options, CustomKernelCreator custom_kernel_creator); - // Given a list of attrs on a function, extracts the "_target" attribute which - // indicates which device to run the function on. If it can't find the _target - // attribute, returns "". Canonicalizes the device name. - static string ObtainFunctionTarget(const AttrSlice& attrs); - // Sends `tensors_to_send` from `source_device` to `target_device` using // `rendezvous`. `key_prefix` is used as a prefix for the keys sent to the // Rendezvous. `device_context` should be the DeviceContext of the device @@ -121,6 +116,7 @@ class ProcessFunctionLibraryRuntime { // Allows for function_name to be instantiated on different devices // as specified in attrs. Status Instantiate(const string& function_name, AttrSlice attrs, + const FunctionLibraryRuntime::InstantiateOptions& options, FunctionLibraryRuntime::Handle* handle); // Delegates to the local FLR that owns state corresponding to `handle` and diff --git a/tensorflow/core/common_runtime/process_function_library_runtime_test.cc b/tensorflow/core/common_runtime/process_function_library_runtime_test.cc index 270e46dfe9..f11b7a851f 100644 --- a/tensorflow/core/common_runtime/process_function_library_runtime_test.cc +++ b/tensorflow/core/common_runtime/process_function_library_runtime_test.cc @@ -49,10 +49,12 @@ class ProcessFunctionLibraryRuntimeTest : public ::testing::Test { } Status Run(const string& name, FunctionLibraryRuntime::Options opts, - test::function::Attrs attrs, const std::vector<Tensor>& args, - std::vector<Tensor*> rets) { + test::function::Attrs attrs, + const FunctionLibraryRuntime::InstantiateOptions& instantiate_opts, + const std::vector<Tensor>& args, std::vector<Tensor*> rets) { FunctionLibraryRuntime::Handle handle; - Status status = proc_flr_->Instantiate(name, attrs, &handle); + Status status = + proc_flr_->Instantiate(name, attrs, instantiate_opts, &handle); if (!status.ok()) { return status; } @@ -142,21 +144,6 @@ TEST_F(ProcessFunctionLibraryRuntimeTest, Basic) { rendezvous_->Unref(); } -TEST_F(ProcessFunctionLibraryRuntimeTest, ObtainFunctionTarget) { - AttrSlice empty_attrs; - string target = - ProcessFunctionLibraryRuntime::ObtainFunctionTarget(empty_attrs); - EXPECT_EQ("", target); - - AttrValueMap attr_values; - AttrValue v; - v.set_s("/job:a/replica:0/task:0/cpu:1"); - AddAttr("_target", v, &attr_values); - AttrSlice attrs(&attr_values); - target = ProcessFunctionLibraryRuntime::ObtainFunctionTarget(attrs); - EXPECT_EQ("/job:a/replica:0/task:0/device:CPU:1", target); -} - TEST_F(ProcessFunctionLibraryRuntimeTest, GetDeviceIncarnation) { Init({}); int64 incarnation; @@ -178,10 +165,8 @@ TEST_F(ProcessFunctionLibraryRuntimeTest, SingleCall) { opts.remote_execution = true; auto x = test::AsTensor<float>({1, 2, 3, 4}); Tensor y; - TF_CHECK_OK( - Run("XTimesTwo", opts, - {{"T", DT_FLOAT}, {"_target", "/job:a/replica:0/task:0/cpu:0"}}, {x}, - {&y})); + TF_CHECK_OK(Run("XTimesTwo", opts, {{"T", DT_FLOAT}}, + {"/job:a/replica:0/task:0/cpu:0"}, {x}, {&y})); test::ExpectTensorEqual<float>(y, test::AsTensor<float>({2, 4, 6, 8})); rendezvous_->Unref(); } @@ -193,8 +178,8 @@ TEST_F(ProcessFunctionLibraryRuntimeTest, SingleCallFindDevice) { opts.rendezvous = rendezvous_; opts.remote_execution = true; Tensor y; - TF_CHECK_OK(Run("FindDevice", opts, - {{"_target", "/job:a/replica:0/task:0/cpu:0"}}, {}, {&y})); + TF_CHECK_OK( + Run("FindDevice", opts, {}, {"/job:a/replica:0/task:0/cpu:0"}, {}, {&y})); test::ExpectTensorEqual<string>( y, test::AsTensor<string>({"/job:a/replica:0/task:0/device:CPU:0"}, TensorShape({}))); @@ -209,15 +194,11 @@ TEST_F(ProcessFunctionLibraryRuntimeTest, MultipleCallsSameDeviceXTimes) { opts.rendezvous = rendezvous_; opts.remote_execution = true; Tensor y; - TF_CHECK_OK( - Run("XTimesTwo", opts, - {{"T", DT_FLOAT}, {"_target", "/job:a/replica:0/task:0/cpu:0"}}, {x}, - {&y})); + TF_CHECK_OK(Run("XTimesTwo", opts, {{"T", DT_FLOAT}}, + {"/job:a/replica:0/task:0/cpu:0"}, {x}, {&y})); test::ExpectTensorEqual<float>(y, test::AsTensor<float>({2, 4, 6, 8})); - TF_CHECK_OK( - Run("XTimesFour", opts, - {{"T", DT_FLOAT}, {"_target", "/job:a/replica:0/task:0/cpu:0"}}, {x}, - {&y})); + TF_CHECK_OK(Run("XTimesFour", opts, {{"T", DT_FLOAT}}, + {"/job:a/replica:0/task:0/cpu:0"}, {x}, {&y})); test::ExpectTensorEqual<float>(y, test::AsTensor<float>({4, 8, 12, 16})); rendezvous_->Unref(); } @@ -229,13 +210,13 @@ TEST_F(ProcessFunctionLibraryRuntimeTest, MultipleCallsSameDeviceFindDevice) { opts.rendezvous = rendezvous_; opts.remote_execution = true; Tensor y; - TF_CHECK_OK(Run("FindDevice", opts, - {{"_target", "/job:a/replica:0/task:0/cpu:1"}}, {}, {&y})); + TF_CHECK_OK( + Run("FindDevice", opts, {}, {"/job:a/replica:0/task:0/cpu:1"}, {}, {&y})); test::ExpectTensorEqual<string>( y, test::AsTensor<string>({"/job:a/replica:0/task:0/device:CPU:1"}, TensorShape({}))); - TF_CHECK_OK(Run("FindDevice", opts, - {{"_target", "/job:a/replica:0/task:0/cpu:1"}}, {}, {&y})); + TF_CHECK_OK( + Run("FindDevice", opts, {}, {"/job:a/replica:0/task:0/cpu:1"}, {}, {&y})); test::ExpectTensorEqual<string>( y, test::AsTensor<string>({"/job:a/replica:0/task:0/device:CPU:1"}, TensorShape({}))); @@ -249,11 +230,13 @@ TEST_F(ProcessFunctionLibraryRuntimeTest, MultipleCallsDiffDeviceFindDevice) { opts.rendezvous = rendezvous_; opts.remote_execution = true; Tensor y; - TF_CHECK_OK(Run("FindDevice", opts, {{"_target", "/cpu:0"}}, {}, {&y})); + TF_CHECK_OK(Run("FindDevice", opts, {}, + {"/job:a/replica:0/task:0/device:CPU:0"}, {}, {&y})); test::ExpectTensorEqual<string>( y, test::AsTensor<string>({"/job:a/replica:0/task:0/device:CPU:0"}, TensorShape({}))); - TF_CHECK_OK(Run("FindDevice", opts, {{"_target", "/cpu:1"}}, {}, {&y})); + TF_CHECK_OK(Run("FindDevice", opts, {}, + {"/job:a/replica:0/task:0/device:CPU:1"}, {}, {&y})); test::ExpectTensorEqual<string>( y, test::AsTensor<string>({"/job:a/replica:0/task:0/device:CPU:1"}, TensorShape({}))); diff --git a/tensorflow/core/distributed_runtime/cluster_function_library_runtime.cc b/tensorflow/core/distributed_runtime/cluster_function_library_runtime.cc index d84b69d06b..3a8d591236 100644 --- a/tensorflow/core/distributed_runtime/cluster_function_library_runtime.cc +++ b/tensorflow/core/distributed_runtime/cluster_function_library_runtime.cc @@ -26,10 +26,10 @@ namespace tensorflow { /* static */ Status ClusterFunctionLibraryRuntime::ConstructFunctionGraph( - const OpDef& sig, AttrSlice attrs, GraphDef* g, + const OpDef& sig, AttrSlice attrs, + const FunctionLibraryRuntime::InstantiateOptions& options, GraphDef* g, std::vector<string>* send_keys, std::vector<string>* recv_keys) { - const string& target = - ProcessFunctionLibraryRuntime::ObtainFunctionTarget(attrs); + const string& target = options.target; // Construct recv nodes for each input argument. int i = 0; for (const auto& in : sig.input_arg()) { @@ -119,16 +119,16 @@ ClusterFunctionLibraryRuntime::~ClusterFunctionLibraryRuntime() { Status ClusterFunctionLibraryRuntime::Instantiate( const string& function_name, const FunctionLibraryDefinition& lib_def, - AttrSlice attrs, FunctionLibraryRuntime::LocalHandle* handle) { - const string& target = - ProcessFunctionLibraryRuntime::ObtainFunctionTarget(attrs); - WorkerInterface* wi = worker_session_->worker_cache->CreateWorker(target); + AttrSlice attrs, const FunctionLibraryRuntime::InstantiateOptions& options, + FunctionLibraryRuntime::LocalHandle* handle) { + WorkerInterface* wi = + worker_session_->worker_cache->CreateWorker(options.target); if (wi == nullptr) { std::vector<string> workers; worker_session_->worker_cache->ListWorkers(&workers); return errors::InvalidArgument( - "Could not find worker with target: ", target, + "Could not find worker with target: ", options.target, " Available workers: ", str_util::Join(workers, ", ")); } @@ -137,8 +137,8 @@ Status ClusterFunctionLibraryRuntime::Instantiate( const OpDef& sig = fdef->signature(); GraphDef gdef; std::vector<string> send_keys, recv_keys; - TF_RETURN_IF_ERROR( - ConstructFunctionGraph(sig, attrs, &gdef, &send_keys, &recv_keys)); + TF_RETURN_IF_ERROR(ConstructFunctionGraph(sig, attrs, options, &gdef, + &send_keys, &recv_keys)); *gdef.mutable_library() = lib_def.ToProto(); RegisterGraphRequest req; @@ -152,8 +152,8 @@ Status ClusterFunctionLibraryRuntime::Instantiate( mutex_lock l(mu_); *handle = function_data_.size(); - function_data_.push_back( - FunctionData(resp.graph_handle(), target, wi, send_keys, recv_keys)); + function_data_.push_back(FunctionData(resp.graph_handle(), options.target, wi, + send_keys, recv_keys)); return Status::OK(); } diff --git a/tensorflow/core/distributed_runtime/cluster_function_library_runtime.h b/tensorflow/core/distributed_runtime/cluster_function_library_runtime.h index dd4ea68f57..3deb80dff7 100644 --- a/tensorflow/core/distributed_runtime/cluster_function_library_runtime.h +++ b/tensorflow/core/distributed_runtime/cluster_function_library_runtime.h @@ -34,6 +34,7 @@ class ClusterFunctionLibraryRuntime : public DistributedFunctionLibraryRuntime { Status Instantiate(const string& function_name, const FunctionLibraryDefinition& lib_def, AttrSlice attrs, + const FunctionLibraryRuntime::InstantiateOptions& options, FunctionLibraryRuntime::LocalHandle* handle) override; void Run(const FunctionLibraryRuntime::Options& opts, @@ -42,10 +43,10 @@ class ClusterFunctionLibraryRuntime : public DistributedFunctionLibraryRuntime { FunctionLibraryRuntime::DoneCallback done) override; private: - static Status ConstructFunctionGraph(const OpDef& sig, AttrSlice attrs, - GraphDef* g, - std::vector<string>* send_keys, - std::vector<string>* recv_keys); + static Status ConstructFunctionGraph( + const OpDef& sig, AttrSlice attrs, + const FunctionLibraryRuntime::InstantiateOptions& options, GraphDef* g, + std::vector<string>* send_keys, std::vector<string>* recv_keys); friend class ClusterFunctionLibraryRuntimeTest; mutable mutex mu_; diff --git a/tensorflow/core/distributed_runtime/cluster_function_library_runtime_test.cc b/tensorflow/core/distributed_runtime/cluster_function_library_runtime_test.cc index 6dd8b9ec73..98512bce18 100644 --- a/tensorflow/core/distributed_runtime/cluster_function_library_runtime_test.cc +++ b/tensorflow/core/distributed_runtime/cluster_function_library_runtime_test.cc @@ -47,30 +47,31 @@ class ClusterFunctionLibraryRuntimeTest : public ::testing::Test { new ClusterFunctionLibraryRuntime(worker_session_.get())); } - Status ConstructFunctionGraphHelper(const OpDef& sig, - test::function::Attrs attrs, GraphDef* g, - std::vector<string>* send_keys, - std::vector<string>* recv_keys) { + Status ConstructFunctionGraphHelper( + const OpDef& sig, test::function::Attrs attrs, + const FunctionLibraryRuntime::InstantiateOptions& options, GraphDef* g, + std::vector<string>* send_keys, std::vector<string>* recv_keys) { return ClusterFunctionLibraryRuntime::ConstructFunctionGraph( - sig, attrs, g, send_keys, recv_keys); + sig, attrs, options, g, send_keys, recv_keys); } Status Instantiate(const string& function_name, const FunctionLibraryDefinition& lib_def, test::function::Attrs attrs, + const FunctionLibraryRuntime::InstantiateOptions& options, FunctionLibraryRuntime::LocalHandle* local_handle) { - return cluster_flr_->Instantiate(function_name, lib_def, attrs, + return cluster_flr_->Instantiate(function_name, lib_def, attrs, options, local_handle); } - Status InstantiateAndRun(const string& function_name, - const FunctionLibraryDefinition& lib_def, - test::function::Attrs attrs, - const std::vector<Tensor>& args, - std::vector<Tensor*> rets) { + Status InstantiateAndRun( + const string& function_name, const FunctionLibraryDefinition& lib_def, + test::function::Attrs attrs, + const FunctionLibraryRuntime::InstantiateOptions& options, + const std::vector<Tensor>& args, std::vector<Tensor*> rets) { FunctionLibraryRuntime::LocalHandle handle; - TF_RETURN_IF_ERROR( - cluster_flr_->Instantiate(function_name, lib_def, attrs, &handle)); + TF_RETURN_IF_ERROR(cluster_flr_->Instantiate(function_name, lib_def, attrs, + options, &handle)); Notification done; FunctionLibraryRuntime::Options opts; @@ -103,9 +104,9 @@ TEST_F(ClusterFunctionLibraryRuntimeTest, ConstructFunctionGraph) { GraphDef actual; std::vector<string> send_keys, recv_keys; TF_CHECK_OK(ConstructFunctionGraphHelper( - test::function::Swap().signature(), - {{"T", DT_FLOAT}, {"_target", "/job:a/replica:0/task:0/cpu:0"}}, &actual, - &send_keys, &recv_keys)); + test::function::Swap().signature(), {{"T", DT_FLOAT}}, + {"/job:a/replica:0/task:0/device:CPU:0"}, &actual, &send_keys, + &recv_keys)); GraphDef expected; protobuf::TextFormat::ParseFromString(R"( node { @@ -205,7 +206,7 @@ node { attr { key: "_target" value { - s: "/job:a/replica:0/task:0/cpu:0" + s: "/job:a/replica:0/task:0/device:CPU:0" } } } @@ -309,9 +310,9 @@ TEST_F(ClusterFunctionLibraryRuntimeTest, DISABLED_InstantiateAndRun) { Tensor y; auto x = test::AsTensor<int32>({1, 2, 3, 4}); - TF_EXPECT_OK(InstantiateAndRun( - "XTimesTwoInt32", lib_def, - {{"_target", "/job:localhost/replica:0/task:1/cpu:0"}}, {x}, {&y})); + TF_EXPECT_OK(InstantiateAndRun("XTimesTwoInt32", lib_def, {}, + {"/job:localhost/replica:0/task:1/cpu:0"}, {x}, + {&y})); test::ExpectTensorEqual<int32>(y, test::AsTensor<int32>({2, 4, 6, 8})); } @@ -324,10 +325,9 @@ TEST_F(ClusterFunctionLibraryRuntimeTest, Tensor y1, y2; auto x1 = test::AsTensor<float>({1, 2, 3, 4}); auto x2 = test::AsTensor<float>({4, 3, 2, 1}); - TF_EXPECT_OK(InstantiateAndRun( - "Swap", lib_def, - {{"T", DT_FLOAT}, {"_target", "/job:localhost/replica:0/task:1/cpu:0"}}, - {x1, x2}, {&y1, &y2})); + TF_EXPECT_OK(InstantiateAndRun("Swap", lib_def, {{"T", DT_FLOAT}}, + {"/job:localhost/replica:0/task:1/cpu:0"}, + {x1, x2}, {&y1, &y2})); test::ExpectTensorEqual<float>(y1, test::AsTensor<float>({4, 3, 2, 1})); test::ExpectTensorEqual<float>(y2, test::AsTensor<float>({1, 2, 3, 4})); } diff --git a/tensorflow/core/framework/function.cc b/tensorflow/core/framework/function.cc index d757e962e5..783015481a 100644 --- a/tensorflow/core/framework/function.cc +++ b/tensorflow/core/framework/function.cc @@ -795,12 +795,17 @@ uint64 FunctionDefHash(const FunctionDef& fdef) { return h; } -string Canonicalize(const string& funcname, AttrSlice attrs) { +string Canonicalize(const string& funcname, AttrSlice attrs, + const FunctionLibraryRuntime::InstantiateOptions& options) { std::vector<string> entries; - entries.reserve(attrs.size()); + entries.reserve(options.target.empty() ? attrs.size() : (attrs.size() + 1)); for (auto p : attrs) { entries.push_back(strings::StrCat(p.first, "=", Print(p.second))); } + if (!options.target.empty()) { + entries.push_back( + strings::StrCat("_target", "=", str_util::CEscape(options.target))); + } std::sort(entries.begin(), entries.end()); return strings::StrCat(funcname, "[", str_util::Join(entries, ","), "]"); } diff --git a/tensorflow/core/framework/function.h b/tensorflow/core/framework/function.h index 1a579ab631..e5d0e49dbb 100644 --- a/tensorflow/core/framework/function.h +++ b/tensorflow/core/framework/function.h @@ -234,15 +234,6 @@ bool FunctionDefsEqual(const FunctionDef& f1, const FunctionDef& f2); // same. uint64 FunctionDefHash(const FunctionDef& fdef); -// Returns a canonicalized string for the instantiation of the -// function of the given "name" and attributes "attrs". -// -// The returned string is guaranteed to be stable within one address -// space. But it may be change as the implementation -// evolves. Therefore, it should not be persisted or compared across -// address spaces. -string Canonicalize(const string& funcname, AttrSlice attrs); - class CallFrameInterface { public: virtual ~CallFrameInterface() {} @@ -418,9 +409,23 @@ class FunctionLibraryRuntime { // // Returns OK and fills in "handle" if the instantiation succeeds. // Otherwise returns an error and "handle" is undefined. + struct InstantiateOptions { + // The canonical device name of the device on which the function + // should be instantiated. If empty, the function will be + // instantiated on the local device. + string target; + + // TODO(b/70352992): Add an API for allowing a different + // FunctionLibraryDefinition to be overlaid on this runtime's library. + }; typedef uint64 Handle; virtual Status Instantiate(const string& function_name, AttrSlice attrs, + const InstantiateOptions& options, Handle* handle) = 0; + Status Instantiate(const string& function_name, AttrSlice attrs, + Handle* handle) { + return Instantiate(function_name, attrs, {}, handle); + } // Releases state associated with the handle. virtual Status ReleaseHandle(Handle handle) = 0; @@ -502,6 +507,19 @@ class FunctionLibraryRuntime { typedef uint64 LocalHandle; }; +// Returns a canonicalized string for the instantiation of the +// function of the given "name", attributes "attrs", and "options". +// +// The returned string is guaranteed to be stable within one address +// space. But it may be change as the implementation +// evolves. Therefore, it should not be persisted or compared across +// address spaces. +string Canonicalize(const string& funcname, AttrSlice attrs, + const FunctionLibraryRuntime::InstantiateOptions& options); +inline string Canonicalize(const string& funcname, AttrSlice attrs) { + return Canonicalize(funcname, attrs, {}); +} + const FunctionLibraryRuntime::Handle kInvalidHandle = -1; const FunctionLibraryRuntime::LocalHandle kInvalidLocalHandle = -1; typedef std::function<Status(FunctionLibraryRuntime*, const NodeDef&, @@ -514,10 +532,11 @@ class DistributedFunctionLibraryRuntime { virtual ~DistributedFunctionLibraryRuntime() {} // The _target attr in attrs determines where the function is instantiated. - virtual Status Instantiate(const string& function_name, - const FunctionLibraryDefinition& lib_def, - AttrSlice attrs, - FunctionLibraryRuntime::LocalHandle* handle) = 0; + virtual Status Instantiate( + const string& function_name, const FunctionLibraryDefinition& lib_def, + AttrSlice attrs, + const FunctionLibraryRuntime::InstantiateOptions& options, + FunctionLibraryRuntime::LocalHandle* handle) = 0; // opts.runner isn't used for execution. virtual void Run(const FunctionLibraryRuntime::Options& opts, diff --git a/tensorflow/core/kernels/function_ops.cc b/tensorflow/core/kernels/function_ops.cc index f469f41e06..facac10f66 100644 --- a/tensorflow/core/kernels/function_ops.cc +++ b/tensorflow/core/kernels/function_ops.cc @@ -296,21 +296,19 @@ class RemoteCallOp : public AsyncOpKernel { void ComputeAsync(OpKernelContext* ctx, DoneCallback done) override { const Tensor* target; OP_REQUIRES_OK_ASYNC(ctx, ctx->input("target", &target), done); - AttrValueMap attr_values = func_.attr(); - AttrValue v; const string& target_device = DeviceNameUtils::CanonicalizeDeviceName(target->scalar<string>()()); - v.set_s(target_device); - AddAttr("_target", v, &attr_values); FunctionLibraryRuntime* lib = ctx->function_library(); OP_REQUIRES_ASYNC(ctx, lib != nullptr, errors::Internal("No function library is provided."), done); + AttrValueMap attr_values = func_.attr(); FunctionLibraryRuntime::Handle handle; - OP_REQUIRES_OK_ASYNC( - ctx, lib->Instantiate(func_.name(), AttrSlice(&attr_values), &handle), - done); + OP_REQUIRES_OK_ASYNC(ctx, + lib->Instantiate(func_.name(), AttrSlice(&attr_values), + {target_device}, &handle), + done); OpInputList arguments; OP_REQUIRES_OK_ASYNC(ctx, ctx->input_list("args", &arguments), done); diff --git a/tensorflow/python/debug/wrappers/framework.py b/tensorflow/python/debug/wrappers/framework.py index 909150eb6a..00c38eac10 100644 --- a/tensorflow/python/debug/wrappers/framework.py +++ b/tensorflow/python/debug/wrappers/framework.py @@ -154,7 +154,8 @@ class OnSessionInitRequest(object): sess: A tensorflow Session object. """ - _check_type(sess, (session.BaseSession, monitored_session.MonitoredSession)) + _check_type(sess, (session.SessionInterface, + monitored_session.MonitoredSession)) self.session = sess @@ -358,7 +359,8 @@ class BaseDebugWrapperSession(session.SessionInterface): NotImplementedError: If a non-DirectSession sess object is received. """ - _check_type(sess, (session.BaseSession, monitored_session.MonitoredSession)) + _check_type(sess, (session.SessionInterface, + monitored_session.MonitoredSession)) # The session being wrapped. self._sess = sess diff --git a/tensorflow/python/debug/wrappers/framework_test.py b/tensorflow/python/debug/wrappers/framework_test.py index 73e08ce7d5..5240e0d0ad 100644 --- a/tensorflow/python/debug/wrappers/framework_test.py +++ b/tensorflow/python/debug/wrappers/framework_test.py @@ -271,9 +271,9 @@ class DebugWrapperSessionTest(test_util.TensorFlowTestCase): def testSessionInitInvalidSessionType(self): """Attempt to wrap a non-Session-type object should cause an exception.""" - wrapper = TestDebugWrapperSessionBadAction(self._sess) + sess = "not a session" with self.assertRaisesRegexp(TypeError, "Expected type .*; got type .*"): - TestDebugWrapperSessionBadAction(wrapper) + TestDebugWrapperSessionBadAction(sess) def testSessionInitBadActionValue(self): with self.assertRaisesRegexp( |