diff options
author | James Qin <jamesqin@google.com> | 2018-01-05 14:55:24 -0800 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-01-05 14:58:51 -0800 |
commit | cc94f3888eb008fd78a2d95e30ecf76fbce0e0af (patch) | |
tree | cc126be226c801b8fc374b5ff4050deace0c098d | |
parent | c152d4bbac33d9dc10f7278226aaeb42fce799b5 (diff) |
Automated g4 rollback of changelist 180848930
PiperOrigin-RevId: 180979141
14 files changed, 171 insertions, 150 deletions
diff --git a/tensorflow/contrib/data/kernels/prefetching_kernels.cc b/tensorflow/contrib/data/kernels/prefetching_kernels.cc index 742835c964..c9a3537c70 100644 --- a/tensorflow/contrib/data/kernels/prefetching_kernels.cc +++ b/tensorflow/contrib/data/kernels/prefetching_kernels.cc @@ -83,8 +83,11 @@ class FunctionBufferingResource : public ResourceBase { return Status::OK(); } AttrValueMap attr_values = func_.attr(); - return lib_->Instantiate(func_.name(), AttrSlice(&attr_values), - {target_device_}, &handle_); + AttrValue v; + v.set_s(target_device_); + AddAttr("_target", v, &attr_values); + + return lib_->Instantiate(func_.name(), AttrSlice(&attr_values), &handle_); } // Returns true if we've got to the end of the sequence and exhausted the diff --git a/tensorflow/core/common_runtime/function.cc b/tensorflow/core/common_runtime/function.cc index 286266a485..51d7f98f72 100644 --- a/tensorflow/core/common_runtime/function.cc +++ b/tensorflow/core/common_runtime/function.cc @@ -152,7 +152,6 @@ class FunctionLibraryRuntimeImpl : public FunctionLibraryRuntime { ~FunctionLibraryRuntimeImpl() override; Status Instantiate(const string& function_name, AttrSlice attrs, - const InstantiateOptions& options, Handle* handle) override; Status ReleaseHandle(Handle handle) override; @@ -224,7 +223,7 @@ class FunctionLibraryRuntimeImpl : public FunctionLibraryRuntime { Status GetOrCreateItem(Handle handle, Item** item); Status InstantiateSymbolicGradient(const NameAttrList& func, FunctionBody** g_body); - bool IsLocalTarget(const InstantiateOptions& options); + bool IsLocalTarget(const AttrSlice& attrs); AttrValueMap FixAttrs(const AttrSlice& attrs); void RunRemote(const Options& opts, Handle handle, gtl::ArraySlice<Tensor> args, std::vector<Tensor>* rets, @@ -353,8 +352,7 @@ Status FunctionLibraryRuntimeImpl::CreateKernel(const NodeDef& ndef, // Try to instantiate this function for the func/attr. Maybe it's // cached already. Handle handle; - TF_RETURN_IF_ERROR( - Instantiate(ndef.op(), AttrSlice(&ndef.attr()), {}, &handle)); + TF_RETURN_IF_ERROR(Instantiate(ndef.op(), AttrSlice(&ndef.attr()), &handle)); const FunctionBody* fbody = GetFunctionBody(handle); CHECK_NOTNULL(fbody); @@ -413,7 +411,7 @@ Status FunctionLibraryRuntimeImpl::InstantiateSymbolicGradient( // f is a user-defined function. Handle f_handle; TF_RETURN_IF_ERROR( - Instantiate(func.name(), AttrSlice(&func.attr()), {}, &f_handle)); + Instantiate(func.name(), AttrSlice(&func.attr()), &f_handle)); const FunctionBody* f_body = GetFunctionBody(f_handle); CHECK_NOTNULL(f_body); *g_body = SymbolicGradient(*f_body); @@ -421,25 +419,42 @@ Status FunctionLibraryRuntimeImpl::InstantiateSymbolicGradient( return Status::OK(); } -bool FunctionLibraryRuntimeImpl::IsLocalTarget( - const InstantiateOptions& options) { +bool FunctionLibraryRuntimeImpl::IsLocalTarget(const AttrSlice& attrs) { if (device_ == nullptr) return true; - if (options.target.empty()) return true; + string target = ProcessFunctionLibraryRuntime::ObtainFunctionTarget(attrs); + if (target.empty()) return true; Device* target_device; - if (!device_mgr_->LookupDevice(options.target, &target_device).ok()) { + if (!device_mgr_->LookupDevice(target, &target_device).ok()) { return false; } return target_device == device_; } -Status FunctionLibraryRuntimeImpl::Instantiate( - const string& function_name, AttrSlice attrs, - const InstantiateOptions& options, Handle* handle) { - if (!IsLocalTarget(options)) { - return parent_->Instantiate(function_name, attrs, options, handle); +AttrValueMap FunctionLibraryRuntimeImpl::FixAttrs(const AttrSlice& attrs) { + AttrValueMap value_map; + for (auto it : attrs) { + value_map[it.first] = it.second; + } + if (attrs.Find("_target") != nullptr) { + return value_map; + } + AttrValue v; + v.set_s(device_name_); + AddAttr("_target", v, &value_map); + return value_map; +} + +Status FunctionLibraryRuntimeImpl::Instantiate(const string& function_name, + AttrSlice attrs, + Handle* handle) { + AttrValueMap value_map = FixAttrs(attrs); + AttrSlice new_attrs(&value_map); + + if (!IsLocalTarget(new_attrs)) { + return parent_->Instantiate(function_name, new_attrs, handle); } - const string key = Canonicalize(function_name, attrs, options); + const string key = Canonicalize(function_name, new_attrs); *handle = parent_->GetHandle(key); if (*handle != kInvalidHandle) { return Status::OK(); @@ -448,7 +463,7 @@ Status FunctionLibraryRuntimeImpl::Instantiate( Status s; FunctionBody* fbody = nullptr; if (function_name == kGradientOp) { - const AttrValue* f = attrs.Find(kFuncAttr); + const AttrValue* f = new_attrs.Find(kFuncAttr); if (f == nullptr) { return errors::InvalidArgument("SymbolicGradient is missing attr: f"); } @@ -458,7 +473,7 @@ Status FunctionLibraryRuntimeImpl::Instantiate( } const string grad = lib_def_->FindGradient(func.name()); if (!grad.empty()) { - return Instantiate(grad, AttrSlice(&func.attr()), options, handle); + return Instantiate(grad, AttrSlice(&func.attr()), handle); } TF_RETURN_IF_ERROR(InstantiateSymbolicGradient(func, &fbody)); } else { @@ -466,7 +481,7 @@ Status FunctionLibraryRuntimeImpl::Instantiate( if (fdef == nullptr) { return errors::NotFound("Function ", function_name, " is not defined."); } - TF_RETURN_IF_ERROR(FunctionDefToBody(*fdef, attrs, &fbody)); + TF_RETURN_IF_ERROR(FunctionDefToBody(*fdef, new_attrs, &fbody)); } { diff --git a/tensorflow/core/common_runtime/function_test.cc b/tensorflow/core/common_runtime/function_test.cc index 2dacacea7b..d4181ff48c 100644 --- a/tensorflow/core/common_runtime/function_test.cc +++ b/tensorflow/core/common_runtime/function_test.cc @@ -191,14 +191,11 @@ class FunctionLibraryRuntimeTest : public ::testing::Test { Status Instantiate(FunctionLibraryRuntime* flr, const string& name, test::function::Attrs attrs, FunctionLibraryRuntime::Handle* handle) { - return flr->Instantiate(name, attrs, handle); - } - - Status Instantiate(FunctionLibraryRuntime* flr, const string& name, - test::function::Attrs attrs, - const FunctionLibraryRuntime::InstantiateOptions& options, - FunctionLibraryRuntime::Handle* handle) { - return flr->Instantiate(name, attrs, options, handle); + Status status = flr->Instantiate(name, attrs, handle); + if (!status.ok()) { + return status; + } + return Status::OK(); } Status InstantiateAndRun(FunctionLibraryRuntime* flr, const string& name, @@ -1091,7 +1088,8 @@ TEST_F(FunctionLibraryRuntimeTest, Gradient_AddSum) { TEST_F(FunctionLibraryRuntimeTest, CrossDevice) { Init({test::function::FindDevice()}); FunctionLibraryRuntime::Handle handle; - TF_CHECK_OK(Instantiate(flr0_, "FindDevice", {}, {"/device:CPU:1"}, &handle)); + TF_CHECK_OK(Instantiate(flr0_, "FindDevice", {{"_target", "/device:CPU:1"}}, + &handle)); Tensor y; FunctionLibraryRuntime::Options opts; diff --git a/tensorflow/core/common_runtime/process_function_library_runtime.cc b/tensorflow/core/common_runtime/process_function_library_runtime.cc index 12947e284a..53a14121d4 100644 --- a/tensorflow/core/common_runtime/process_function_library_runtime.cc +++ b/tensorflow/core/common_runtime/process_function_library_runtime.cc @@ -88,6 +88,16 @@ ProcessFunctionLibraryRuntime::ProcessFunctionLibraryRuntime( std::move(custom_kernel_creator), nullptr /* cluster_flr */) {} /* static */ +string ProcessFunctionLibraryRuntime::ObtainFunctionTarget( + const AttrSlice& attrs) { + const AttrValue* value; + if (!attrs.Find("_target", &value).ok()) { + return ""; + } + return DeviceNameUtils::CanonicalizeDeviceName(value->s()); +} + +/* static */ Status ProcessFunctionLibraryRuntime::SendTensors( const string& source_device, const string& target_device, const string& key_prefix, int64 src_incarnation, @@ -230,23 +240,22 @@ string ProcessFunctionLibraryRuntime::GetDeviceName( Status ProcessFunctionLibraryRuntime::Instantiate( const string& function_name, AttrSlice attrs, - const FunctionLibraryRuntime::InstantiateOptions& options, FunctionLibraryRuntime::Handle* handle) { *handle = kInvalidHandle; - FunctionLibraryRuntime* flr = GetFLR(options.target); + string target = ObtainFunctionTarget(attrs); + FunctionLibraryRuntime* flr = GetFLR(target); if (flr != nullptr) { - return flr->Instantiate(function_name, attrs, options, handle); + return flr->Instantiate(function_name, attrs, handle); } if (parent_ == nullptr) { return errors::Internal( - "Currently don't support instantiating functions on device: ", - options.target); + "Currently don't support instantiating functions on device: ", target); } FunctionLibraryRuntime::Handle cluster_handle; - TF_RETURN_IF_ERROR(parent_->Instantiate(function_name, *lib_def_, attrs, - options, &cluster_handle)); + TF_RETURN_IF_ERROR( + parent_->Instantiate(function_name, *lib_def_, attrs, &cluster_handle)); string function_key = Canonicalize(function_name, attrs); - *handle = AddHandle(function_key, options.target, cluster_handle); + *handle = AddHandle(function_key, target, cluster_handle); return Status::OK(); } diff --git a/tensorflow/core/common_runtime/process_function_library_runtime.h b/tensorflow/core/common_runtime/process_function_library_runtime.h index 38003b7726..3aa7b87286 100644 --- a/tensorflow/core/common_runtime/process_function_library_runtime.h +++ b/tensorflow/core/common_runtime/process_function_library_runtime.h @@ -53,6 +53,11 @@ class ProcessFunctionLibraryRuntime { const OptimizerOptions& optimizer_options, CustomKernelCreator custom_kernel_creator); + // Given a list of attrs on a function, extracts the "_target" attribute which + // indicates which device to run the function on. If it can't find the _target + // attribute, returns "". Canonicalizes the device name. + static string ObtainFunctionTarget(const AttrSlice& attrs); + // Sends `tensors_to_send` from `source_device` to `target_device` using // `rendezvous`. `key_prefix` is used as a prefix for the keys sent to the // Rendezvous. `device_context` should be the DeviceContext of the device @@ -116,7 +121,6 @@ class ProcessFunctionLibraryRuntime { // Allows for function_name to be instantiated on different devices // as specified in attrs. Status Instantiate(const string& function_name, AttrSlice attrs, - const FunctionLibraryRuntime::InstantiateOptions& options, FunctionLibraryRuntime::Handle* handle); // Delegates to the local FLR that owns state corresponding to `handle` and diff --git a/tensorflow/core/common_runtime/process_function_library_runtime_test.cc b/tensorflow/core/common_runtime/process_function_library_runtime_test.cc index f11b7a851f..270e46dfe9 100644 --- a/tensorflow/core/common_runtime/process_function_library_runtime_test.cc +++ b/tensorflow/core/common_runtime/process_function_library_runtime_test.cc @@ -49,12 +49,10 @@ class ProcessFunctionLibraryRuntimeTest : public ::testing::Test { } Status Run(const string& name, FunctionLibraryRuntime::Options opts, - test::function::Attrs attrs, - const FunctionLibraryRuntime::InstantiateOptions& instantiate_opts, - const std::vector<Tensor>& args, std::vector<Tensor*> rets) { + test::function::Attrs attrs, const std::vector<Tensor>& args, + std::vector<Tensor*> rets) { FunctionLibraryRuntime::Handle handle; - Status status = - proc_flr_->Instantiate(name, attrs, instantiate_opts, &handle); + Status status = proc_flr_->Instantiate(name, attrs, &handle); if (!status.ok()) { return status; } @@ -144,6 +142,21 @@ TEST_F(ProcessFunctionLibraryRuntimeTest, Basic) { rendezvous_->Unref(); } +TEST_F(ProcessFunctionLibraryRuntimeTest, ObtainFunctionTarget) { + AttrSlice empty_attrs; + string target = + ProcessFunctionLibraryRuntime::ObtainFunctionTarget(empty_attrs); + EXPECT_EQ("", target); + + AttrValueMap attr_values; + AttrValue v; + v.set_s("/job:a/replica:0/task:0/cpu:1"); + AddAttr("_target", v, &attr_values); + AttrSlice attrs(&attr_values); + target = ProcessFunctionLibraryRuntime::ObtainFunctionTarget(attrs); + EXPECT_EQ("/job:a/replica:0/task:0/device:CPU:1", target); +} + TEST_F(ProcessFunctionLibraryRuntimeTest, GetDeviceIncarnation) { Init({}); int64 incarnation; @@ -165,8 +178,10 @@ TEST_F(ProcessFunctionLibraryRuntimeTest, SingleCall) { opts.remote_execution = true; auto x = test::AsTensor<float>({1, 2, 3, 4}); Tensor y; - TF_CHECK_OK(Run("XTimesTwo", opts, {{"T", DT_FLOAT}}, - {"/job:a/replica:0/task:0/cpu:0"}, {x}, {&y})); + TF_CHECK_OK( + Run("XTimesTwo", opts, + {{"T", DT_FLOAT}, {"_target", "/job:a/replica:0/task:0/cpu:0"}}, {x}, + {&y})); test::ExpectTensorEqual<float>(y, test::AsTensor<float>({2, 4, 6, 8})); rendezvous_->Unref(); } @@ -178,8 +193,8 @@ TEST_F(ProcessFunctionLibraryRuntimeTest, SingleCallFindDevice) { opts.rendezvous = rendezvous_; opts.remote_execution = true; Tensor y; - TF_CHECK_OK( - Run("FindDevice", opts, {}, {"/job:a/replica:0/task:0/cpu:0"}, {}, {&y})); + TF_CHECK_OK(Run("FindDevice", opts, + {{"_target", "/job:a/replica:0/task:0/cpu:0"}}, {}, {&y})); test::ExpectTensorEqual<string>( y, test::AsTensor<string>({"/job:a/replica:0/task:0/device:CPU:0"}, TensorShape({}))); @@ -194,11 +209,15 @@ TEST_F(ProcessFunctionLibraryRuntimeTest, MultipleCallsSameDeviceXTimes) { opts.rendezvous = rendezvous_; opts.remote_execution = true; Tensor y; - TF_CHECK_OK(Run("XTimesTwo", opts, {{"T", DT_FLOAT}}, - {"/job:a/replica:0/task:0/cpu:0"}, {x}, {&y})); + TF_CHECK_OK( + Run("XTimesTwo", opts, + {{"T", DT_FLOAT}, {"_target", "/job:a/replica:0/task:0/cpu:0"}}, {x}, + {&y})); test::ExpectTensorEqual<float>(y, test::AsTensor<float>({2, 4, 6, 8})); - TF_CHECK_OK(Run("XTimesFour", opts, {{"T", DT_FLOAT}}, - {"/job:a/replica:0/task:0/cpu:0"}, {x}, {&y})); + TF_CHECK_OK( + Run("XTimesFour", opts, + {{"T", DT_FLOAT}, {"_target", "/job:a/replica:0/task:0/cpu:0"}}, {x}, + {&y})); test::ExpectTensorEqual<float>(y, test::AsTensor<float>({4, 8, 12, 16})); rendezvous_->Unref(); } @@ -210,13 +229,13 @@ TEST_F(ProcessFunctionLibraryRuntimeTest, MultipleCallsSameDeviceFindDevice) { opts.rendezvous = rendezvous_; opts.remote_execution = true; Tensor y; - TF_CHECK_OK( - Run("FindDevice", opts, {}, {"/job:a/replica:0/task:0/cpu:1"}, {}, {&y})); + TF_CHECK_OK(Run("FindDevice", opts, + {{"_target", "/job:a/replica:0/task:0/cpu:1"}}, {}, {&y})); test::ExpectTensorEqual<string>( y, test::AsTensor<string>({"/job:a/replica:0/task:0/device:CPU:1"}, TensorShape({}))); - TF_CHECK_OK( - Run("FindDevice", opts, {}, {"/job:a/replica:0/task:0/cpu:1"}, {}, {&y})); + TF_CHECK_OK(Run("FindDevice", opts, + {{"_target", "/job:a/replica:0/task:0/cpu:1"}}, {}, {&y})); test::ExpectTensorEqual<string>( y, test::AsTensor<string>({"/job:a/replica:0/task:0/device:CPU:1"}, TensorShape({}))); @@ -230,13 +249,11 @@ TEST_F(ProcessFunctionLibraryRuntimeTest, MultipleCallsDiffDeviceFindDevice) { opts.rendezvous = rendezvous_; opts.remote_execution = true; Tensor y; - TF_CHECK_OK(Run("FindDevice", opts, {}, - {"/job:a/replica:0/task:0/device:CPU:0"}, {}, {&y})); + TF_CHECK_OK(Run("FindDevice", opts, {{"_target", "/cpu:0"}}, {}, {&y})); test::ExpectTensorEqual<string>( y, test::AsTensor<string>({"/job:a/replica:0/task:0/device:CPU:0"}, TensorShape({}))); - TF_CHECK_OK(Run("FindDevice", opts, {}, - {"/job:a/replica:0/task:0/device:CPU:1"}, {}, {&y})); + TF_CHECK_OK(Run("FindDevice", opts, {{"_target", "/cpu:1"}}, {}, {&y})); test::ExpectTensorEqual<string>( y, test::AsTensor<string>({"/job:a/replica:0/task:0/device:CPU:1"}, TensorShape({}))); diff --git a/tensorflow/core/distributed_runtime/cluster_function_library_runtime.cc b/tensorflow/core/distributed_runtime/cluster_function_library_runtime.cc index 3a8d591236..d84b69d06b 100644 --- a/tensorflow/core/distributed_runtime/cluster_function_library_runtime.cc +++ b/tensorflow/core/distributed_runtime/cluster_function_library_runtime.cc @@ -26,10 +26,10 @@ namespace tensorflow { /* static */ Status ClusterFunctionLibraryRuntime::ConstructFunctionGraph( - const OpDef& sig, AttrSlice attrs, - const FunctionLibraryRuntime::InstantiateOptions& options, GraphDef* g, + const OpDef& sig, AttrSlice attrs, GraphDef* g, std::vector<string>* send_keys, std::vector<string>* recv_keys) { - const string& target = options.target; + const string& target = + ProcessFunctionLibraryRuntime::ObtainFunctionTarget(attrs); // Construct recv nodes for each input argument. int i = 0; for (const auto& in : sig.input_arg()) { @@ -119,16 +119,16 @@ ClusterFunctionLibraryRuntime::~ClusterFunctionLibraryRuntime() { Status ClusterFunctionLibraryRuntime::Instantiate( const string& function_name, const FunctionLibraryDefinition& lib_def, - AttrSlice attrs, const FunctionLibraryRuntime::InstantiateOptions& options, - FunctionLibraryRuntime::LocalHandle* handle) { - WorkerInterface* wi = - worker_session_->worker_cache->CreateWorker(options.target); + AttrSlice attrs, FunctionLibraryRuntime::LocalHandle* handle) { + const string& target = + ProcessFunctionLibraryRuntime::ObtainFunctionTarget(attrs); + WorkerInterface* wi = worker_session_->worker_cache->CreateWorker(target); if (wi == nullptr) { std::vector<string> workers; worker_session_->worker_cache->ListWorkers(&workers); return errors::InvalidArgument( - "Could not find worker with target: ", options.target, + "Could not find worker with target: ", target, " Available workers: ", str_util::Join(workers, ", ")); } @@ -137,8 +137,8 @@ Status ClusterFunctionLibraryRuntime::Instantiate( const OpDef& sig = fdef->signature(); GraphDef gdef; std::vector<string> send_keys, recv_keys; - TF_RETURN_IF_ERROR(ConstructFunctionGraph(sig, attrs, options, &gdef, - &send_keys, &recv_keys)); + TF_RETURN_IF_ERROR( + ConstructFunctionGraph(sig, attrs, &gdef, &send_keys, &recv_keys)); *gdef.mutable_library() = lib_def.ToProto(); RegisterGraphRequest req; @@ -152,8 +152,8 @@ Status ClusterFunctionLibraryRuntime::Instantiate( mutex_lock l(mu_); *handle = function_data_.size(); - function_data_.push_back(FunctionData(resp.graph_handle(), options.target, wi, - send_keys, recv_keys)); + function_data_.push_back( + FunctionData(resp.graph_handle(), target, wi, send_keys, recv_keys)); return Status::OK(); } diff --git a/tensorflow/core/distributed_runtime/cluster_function_library_runtime.h b/tensorflow/core/distributed_runtime/cluster_function_library_runtime.h index 3deb80dff7..dd4ea68f57 100644 --- a/tensorflow/core/distributed_runtime/cluster_function_library_runtime.h +++ b/tensorflow/core/distributed_runtime/cluster_function_library_runtime.h @@ -34,7 +34,6 @@ class ClusterFunctionLibraryRuntime : public DistributedFunctionLibraryRuntime { Status Instantiate(const string& function_name, const FunctionLibraryDefinition& lib_def, AttrSlice attrs, - const FunctionLibraryRuntime::InstantiateOptions& options, FunctionLibraryRuntime::LocalHandle* handle) override; void Run(const FunctionLibraryRuntime::Options& opts, @@ -43,10 +42,10 @@ class ClusterFunctionLibraryRuntime : public DistributedFunctionLibraryRuntime { FunctionLibraryRuntime::DoneCallback done) override; private: - static Status ConstructFunctionGraph( - const OpDef& sig, AttrSlice attrs, - const FunctionLibraryRuntime::InstantiateOptions& options, GraphDef* g, - std::vector<string>* send_keys, std::vector<string>* recv_keys); + static Status ConstructFunctionGraph(const OpDef& sig, AttrSlice attrs, + GraphDef* g, + std::vector<string>* send_keys, + std::vector<string>* recv_keys); friend class ClusterFunctionLibraryRuntimeTest; mutable mutex mu_; diff --git a/tensorflow/core/distributed_runtime/cluster_function_library_runtime_test.cc b/tensorflow/core/distributed_runtime/cluster_function_library_runtime_test.cc index 98512bce18..6dd8b9ec73 100644 --- a/tensorflow/core/distributed_runtime/cluster_function_library_runtime_test.cc +++ b/tensorflow/core/distributed_runtime/cluster_function_library_runtime_test.cc @@ -47,31 +47,30 @@ class ClusterFunctionLibraryRuntimeTest : public ::testing::Test { new ClusterFunctionLibraryRuntime(worker_session_.get())); } - Status ConstructFunctionGraphHelper( - const OpDef& sig, test::function::Attrs attrs, - const FunctionLibraryRuntime::InstantiateOptions& options, GraphDef* g, - std::vector<string>* send_keys, std::vector<string>* recv_keys) { + Status ConstructFunctionGraphHelper(const OpDef& sig, + test::function::Attrs attrs, GraphDef* g, + std::vector<string>* send_keys, + std::vector<string>* recv_keys) { return ClusterFunctionLibraryRuntime::ConstructFunctionGraph( - sig, attrs, options, g, send_keys, recv_keys); + sig, attrs, g, send_keys, recv_keys); } Status Instantiate(const string& function_name, const FunctionLibraryDefinition& lib_def, test::function::Attrs attrs, - const FunctionLibraryRuntime::InstantiateOptions& options, FunctionLibraryRuntime::LocalHandle* local_handle) { - return cluster_flr_->Instantiate(function_name, lib_def, attrs, options, + return cluster_flr_->Instantiate(function_name, lib_def, attrs, local_handle); } - Status InstantiateAndRun( - const string& function_name, const FunctionLibraryDefinition& lib_def, - test::function::Attrs attrs, - const FunctionLibraryRuntime::InstantiateOptions& options, - const std::vector<Tensor>& args, std::vector<Tensor*> rets) { + Status InstantiateAndRun(const string& function_name, + const FunctionLibraryDefinition& lib_def, + test::function::Attrs attrs, + const std::vector<Tensor>& args, + std::vector<Tensor*> rets) { FunctionLibraryRuntime::LocalHandle handle; - TF_RETURN_IF_ERROR(cluster_flr_->Instantiate(function_name, lib_def, attrs, - options, &handle)); + TF_RETURN_IF_ERROR( + cluster_flr_->Instantiate(function_name, lib_def, attrs, &handle)); Notification done; FunctionLibraryRuntime::Options opts; @@ -104,9 +103,9 @@ TEST_F(ClusterFunctionLibraryRuntimeTest, ConstructFunctionGraph) { GraphDef actual; std::vector<string> send_keys, recv_keys; TF_CHECK_OK(ConstructFunctionGraphHelper( - test::function::Swap().signature(), {{"T", DT_FLOAT}}, - {"/job:a/replica:0/task:0/device:CPU:0"}, &actual, &send_keys, - &recv_keys)); + test::function::Swap().signature(), + {{"T", DT_FLOAT}, {"_target", "/job:a/replica:0/task:0/cpu:0"}}, &actual, + &send_keys, &recv_keys)); GraphDef expected; protobuf::TextFormat::ParseFromString(R"( node { @@ -206,7 +205,7 @@ node { attr { key: "_target" value { - s: "/job:a/replica:0/task:0/device:CPU:0" + s: "/job:a/replica:0/task:0/cpu:0" } } } @@ -310,9 +309,9 @@ TEST_F(ClusterFunctionLibraryRuntimeTest, DISABLED_InstantiateAndRun) { Tensor y; auto x = test::AsTensor<int32>({1, 2, 3, 4}); - TF_EXPECT_OK(InstantiateAndRun("XTimesTwoInt32", lib_def, {}, - {"/job:localhost/replica:0/task:1/cpu:0"}, {x}, - {&y})); + TF_EXPECT_OK(InstantiateAndRun( + "XTimesTwoInt32", lib_def, + {{"_target", "/job:localhost/replica:0/task:1/cpu:0"}}, {x}, {&y})); test::ExpectTensorEqual<int32>(y, test::AsTensor<int32>({2, 4, 6, 8})); } @@ -325,9 +324,10 @@ TEST_F(ClusterFunctionLibraryRuntimeTest, Tensor y1, y2; auto x1 = test::AsTensor<float>({1, 2, 3, 4}); auto x2 = test::AsTensor<float>({4, 3, 2, 1}); - TF_EXPECT_OK(InstantiateAndRun("Swap", lib_def, {{"T", DT_FLOAT}}, - {"/job:localhost/replica:0/task:1/cpu:0"}, - {x1, x2}, {&y1, &y2})); + TF_EXPECT_OK(InstantiateAndRun( + "Swap", lib_def, + {{"T", DT_FLOAT}, {"_target", "/job:localhost/replica:0/task:1/cpu:0"}}, + {x1, x2}, {&y1, &y2})); test::ExpectTensorEqual<float>(y1, test::AsTensor<float>({4, 3, 2, 1})); test::ExpectTensorEqual<float>(y2, test::AsTensor<float>({1, 2, 3, 4})); } diff --git a/tensorflow/core/framework/function.cc b/tensorflow/core/framework/function.cc index 783015481a..d757e962e5 100644 --- a/tensorflow/core/framework/function.cc +++ b/tensorflow/core/framework/function.cc @@ -795,17 +795,12 @@ uint64 FunctionDefHash(const FunctionDef& fdef) { return h; } -string Canonicalize(const string& funcname, AttrSlice attrs, - const FunctionLibraryRuntime::InstantiateOptions& options) { +string Canonicalize(const string& funcname, AttrSlice attrs) { std::vector<string> entries; - entries.reserve(options.target.empty() ? attrs.size() : (attrs.size() + 1)); + entries.reserve(attrs.size()); for (auto p : attrs) { entries.push_back(strings::StrCat(p.first, "=", Print(p.second))); } - if (!options.target.empty()) { - entries.push_back( - strings::StrCat("_target", "=", str_util::CEscape(options.target))); - } std::sort(entries.begin(), entries.end()); return strings::StrCat(funcname, "[", str_util::Join(entries, ","), "]"); } diff --git a/tensorflow/core/framework/function.h b/tensorflow/core/framework/function.h index e5d0e49dbb..1a579ab631 100644 --- a/tensorflow/core/framework/function.h +++ b/tensorflow/core/framework/function.h @@ -234,6 +234,15 @@ bool FunctionDefsEqual(const FunctionDef& f1, const FunctionDef& f2); // same. uint64 FunctionDefHash(const FunctionDef& fdef); +// Returns a canonicalized string for the instantiation of the +// function of the given "name" and attributes "attrs". +// +// The returned string is guaranteed to be stable within one address +// space. But it may be change as the implementation +// evolves. Therefore, it should not be persisted or compared across +// address spaces. +string Canonicalize(const string& funcname, AttrSlice attrs); + class CallFrameInterface { public: virtual ~CallFrameInterface() {} @@ -409,23 +418,9 @@ class FunctionLibraryRuntime { // // Returns OK and fills in "handle" if the instantiation succeeds. // Otherwise returns an error and "handle" is undefined. - struct InstantiateOptions { - // The canonical device name of the device on which the function - // should be instantiated. If empty, the function will be - // instantiated on the local device. - string target; - - // TODO(b/70352992): Add an API for allowing a different - // FunctionLibraryDefinition to be overlaid on this runtime's library. - }; typedef uint64 Handle; virtual Status Instantiate(const string& function_name, AttrSlice attrs, - const InstantiateOptions& options, Handle* handle) = 0; - Status Instantiate(const string& function_name, AttrSlice attrs, - Handle* handle) { - return Instantiate(function_name, attrs, {}, handle); - } // Releases state associated with the handle. virtual Status ReleaseHandle(Handle handle) = 0; @@ -507,19 +502,6 @@ class FunctionLibraryRuntime { typedef uint64 LocalHandle; }; -// Returns a canonicalized string for the instantiation of the -// function of the given "name", attributes "attrs", and "options". -// -// The returned string is guaranteed to be stable within one address -// space. But it may be change as the implementation -// evolves. Therefore, it should not be persisted or compared across -// address spaces. -string Canonicalize(const string& funcname, AttrSlice attrs, - const FunctionLibraryRuntime::InstantiateOptions& options); -inline string Canonicalize(const string& funcname, AttrSlice attrs) { - return Canonicalize(funcname, attrs, {}); -} - const FunctionLibraryRuntime::Handle kInvalidHandle = -1; const FunctionLibraryRuntime::LocalHandle kInvalidLocalHandle = -1; typedef std::function<Status(FunctionLibraryRuntime*, const NodeDef&, @@ -532,11 +514,10 @@ class DistributedFunctionLibraryRuntime { virtual ~DistributedFunctionLibraryRuntime() {} // The _target attr in attrs determines where the function is instantiated. - virtual Status Instantiate( - const string& function_name, const FunctionLibraryDefinition& lib_def, - AttrSlice attrs, - const FunctionLibraryRuntime::InstantiateOptions& options, - FunctionLibraryRuntime::LocalHandle* handle) = 0; + virtual Status Instantiate(const string& function_name, + const FunctionLibraryDefinition& lib_def, + AttrSlice attrs, + FunctionLibraryRuntime::LocalHandle* handle) = 0; // opts.runner isn't used for execution. virtual void Run(const FunctionLibraryRuntime::Options& opts, diff --git a/tensorflow/core/kernels/function_ops.cc b/tensorflow/core/kernels/function_ops.cc index facac10f66..f469f41e06 100644 --- a/tensorflow/core/kernels/function_ops.cc +++ b/tensorflow/core/kernels/function_ops.cc @@ -296,19 +296,21 @@ class RemoteCallOp : public AsyncOpKernel { void ComputeAsync(OpKernelContext* ctx, DoneCallback done) override { const Tensor* target; OP_REQUIRES_OK_ASYNC(ctx, ctx->input("target", &target), done); + AttrValueMap attr_values = func_.attr(); + AttrValue v; const string& target_device = DeviceNameUtils::CanonicalizeDeviceName(target->scalar<string>()()); + v.set_s(target_device); + AddAttr("_target", v, &attr_values); FunctionLibraryRuntime* lib = ctx->function_library(); OP_REQUIRES_ASYNC(ctx, lib != nullptr, errors::Internal("No function library is provided."), done); - AttrValueMap attr_values = func_.attr(); FunctionLibraryRuntime::Handle handle; - OP_REQUIRES_OK_ASYNC(ctx, - lib->Instantiate(func_.name(), AttrSlice(&attr_values), - {target_device}, &handle), - done); + OP_REQUIRES_OK_ASYNC( + ctx, lib->Instantiate(func_.name(), AttrSlice(&attr_values), &handle), + done); OpInputList arguments; OP_REQUIRES_OK_ASYNC(ctx, ctx->input_list("args", &arguments), done); diff --git a/tensorflow/python/debug/wrappers/framework.py b/tensorflow/python/debug/wrappers/framework.py index 00c38eac10..909150eb6a 100644 --- a/tensorflow/python/debug/wrappers/framework.py +++ b/tensorflow/python/debug/wrappers/framework.py @@ -154,8 +154,7 @@ class OnSessionInitRequest(object): sess: A tensorflow Session object. """ - _check_type(sess, (session.SessionInterface, - monitored_session.MonitoredSession)) + _check_type(sess, (session.BaseSession, monitored_session.MonitoredSession)) self.session = sess @@ -359,8 +358,7 @@ class BaseDebugWrapperSession(session.SessionInterface): NotImplementedError: If a non-DirectSession sess object is received. """ - _check_type(sess, (session.SessionInterface, - monitored_session.MonitoredSession)) + _check_type(sess, (session.BaseSession, monitored_session.MonitoredSession)) # The session being wrapped. self._sess = sess diff --git a/tensorflow/python/debug/wrappers/framework_test.py b/tensorflow/python/debug/wrappers/framework_test.py index 5240e0d0ad..73e08ce7d5 100644 --- a/tensorflow/python/debug/wrappers/framework_test.py +++ b/tensorflow/python/debug/wrappers/framework_test.py @@ -271,9 +271,9 @@ class DebugWrapperSessionTest(test_util.TensorFlowTestCase): def testSessionInitInvalidSessionType(self): """Attempt to wrap a non-Session-type object should cause an exception.""" - sess = "not a session" + wrapper = TestDebugWrapperSessionBadAction(self._sess) with self.assertRaisesRegexp(TypeError, "Expected type .*; got type .*"): - TestDebugWrapperSessionBadAction(sess) + TestDebugWrapperSessionBadAction(wrapper) def testSessionInitBadActionValue(self): with self.assertRaisesRegexp( |