diff options
author | Rohan Jain <rohanj@google.com> | 2017-08-17 11:33:14 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2017-08-17 11:36:50 -0700 |
commit | 935ff49201edd7a6297b313fb9545d1299b9a28d (patch) | |
tree | 36486015014d33efa99d7fd0875eb1545bd518cb /tensorflow/core | |
parent | d94dca2174f0c05dfa03796c3ae31d345813d025 (diff) |
Automated g4 rollback of changelist 165521057
PiperOrigin-RevId: 165604864
Diffstat (limited to 'tensorflow/core')
19 files changed, 156 insertions, 778 deletions
diff --git a/tensorflow/core/BUILD b/tensorflow/core/BUILD index dbe759a6aa..3bff497975 100644 --- a/tensorflow/core/BUILD +++ b/tensorflow/core/BUILD @@ -808,14 +808,12 @@ cc_library( name = "testlib", testonly = 1, srcs = [ - "common_runtime/function_testlib.cc", "common_runtime/kernel_benchmark_testlib.cc", "framework/fake_input.cc", "framework/function_testlib.cc", "graph/testlib.cc", ], hdrs = [ - "common_runtime/function_testlib.h", "common_runtime/kernel_benchmark_testlib.h", "framework/fake_input.h", "framework/function_testlib.h", @@ -2655,14 +2653,17 @@ tf_cc_test( ":test_main", ":testlib", "//tensorflow/cc:cc_ops", - "//tensorflow/cc:cc_ops_internal", - "//tensorflow/cc:function_ops", - "//tensorflow/cc:functional_ops", - "//tensorflow/core/kernels:cast_op", + "//tensorflow/core/kernels:control_flow_ops", "//tensorflow/core/kernels:cwise_op", + "//tensorflow/core/kernels:dense_update_ops", + "//tensorflow/core/kernels:fifo_queue_op", "//tensorflow/core/kernels:function_ops", + "//tensorflow/core/kernels:identity_op", "//tensorflow/core/kernels:matmul_op", - "//tensorflow/core/kernels:shape_ops", + "//tensorflow/core/kernels:ops_util", + "//tensorflow/core/kernels:queue_ops", + "//tensorflow/core/kernels:session_ops", + "//tensorflow/core/kernels:variable_ops", "//third_party/eigen3", ], ) diff --git a/tensorflow/core/common_runtime/function.cc b/tensorflow/core/common_runtime/function.cc index 829fba780f..6b529d8f13 100644 --- a/tensorflow/core/common_runtime/function.cc +++ b/tensorflow/core/common_runtime/function.cc @@ -139,14 +139,15 @@ static Node* AddRet(Graph* g, Endpoint input, int index) { return ret; } +static const FunctionLibraryRuntime::Handle kInvalidHandle = -1; + class FunctionLibraryRuntimeImpl : public FunctionLibraryRuntime { public: FunctionLibraryRuntimeImpl(const DeviceMgr* dmgr, Env* env, Device* device, int graph_def_version, const FunctionLibraryDefinition* lib_def, const OptimizerOptions& optimizer_options, - CustomKernelCreator custom_kernel_creator, - ProcessFunctionLibraryRuntime* parent); + CustomKernelCreator custom_kernel_creator); ~FunctionLibraryRuntimeImpl() override; @@ -183,13 +184,17 @@ class FunctionLibraryRuntimeImpl : public FunctionLibraryRuntime { const FunctionLibraryDefinition* const lib_def_; GraphOptimizer optimizer_; const CustomKernelCreator custom_kernel_creator_; - const string device_name_; std::function<Status(const string&, const OpDef**)> get_func_sig_; std::function<Status(const NodeDef&, OpKernel**)> create_kernel_; mutable mutex mu_; + // Maps function instantiation to a handle. The key is a + // canonicalized representation of the function name and + // instantiation attrs. The handle is an index into the items_. + std::unordered_map<string, Handle> table_ GUARDED_BY(mu_); + // func_graphs_ never shrinks or reorders its members. std::vector<FunctionBody*> func_graphs_ GUARDED_BY(mu_); @@ -203,15 +208,12 @@ class FunctionLibraryRuntimeImpl : public FunctionLibraryRuntime { }; std::vector<Item*> items_; - ProcessFunctionLibraryRuntime* parent_ = nullptr; // not owned. - Status FunctionDefToBody(const FunctionDef& fdef, AttrSlice attrs, FunctionBody** fbody); Status CreateItem(Handle handle, Item** item); Status GetOrCreateItem(Handle handle, Item** item); Status InstantiateSymbolicGradient(const NameAttrList& func, FunctionBody** g_body); - bool IsLocalTarget(const AttrSlice& attrs); TF_DISALLOW_COPY_AND_ASSIGN(FunctionLibraryRuntimeImpl); }; @@ -220,19 +222,14 @@ FunctionLibraryRuntimeImpl::FunctionLibraryRuntimeImpl( const DeviceMgr* dmgr, Env* env, Device* device, int graph_def_version, const FunctionLibraryDefinition* lib_def, const OptimizerOptions& optimizer_options, - CustomKernelCreator custom_kernel_creator, - ProcessFunctionLibraryRuntime* parent) + CustomKernelCreator custom_kernel_creator) : device_mgr_(dmgr), device_(device), env_(env), graph_def_version_(graph_def_version), lib_def_(lib_def), optimizer_(optimizer_options), - custom_kernel_creator_(std::move(custom_kernel_creator)), - device_name_(device_ == nullptr - ? ProcessFunctionLibraryRuntime::kDefaultFLRDevice - : device_->name()), - parent_(parent) { + custom_kernel_creator_(std::move(custom_kernel_creator)) { get_func_sig_ = [this](const string& op, const OpDef** sig) { return lib_def_->LookUpOpDef(op, sig); }; @@ -297,17 +294,10 @@ class CallOp : public AsyncOpKernel { }; const FunctionBody* FunctionLibraryRuntimeImpl::GetFunctionBody(Handle h) { - LocalHandle local_handle = parent_->GetHandleOnDevice(device_name_, h); - if (local_handle == kInvalidLocalHandle) { - LOG(ERROR) << "Could not find Handle: " << h - << " on device: " << device_name_; - return nullptr; - } - mutex_lock l(mu_); - CHECK_LE(0, local_handle); - CHECK_LT(local_handle, func_graphs_.size()); - return func_graphs_[local_handle]; + CHECK_LE(static_cast<Handle>(0), h); + CHECK_LT(h, func_graphs_.size()); + return func_graphs_[h]; } Status FunctionLibraryRuntimeImpl::CreateKernel(const NodeDef& ndef, @@ -403,24 +393,16 @@ Status FunctionLibraryRuntimeImpl::InstantiateSymbolicGradient( return Status::OK(); } -bool FunctionLibraryRuntimeImpl::IsLocalTarget(const AttrSlice& attrs) { - if (device_ == nullptr) return true; - string target = ProcessFunctionLibraryRuntime::ObtainFunctionTarget(attrs); - if (target.empty()) return true; - return target == device_->name(); -} - Status FunctionLibraryRuntimeImpl::Instantiate(const string& function_name, AttrSlice attrs, Handle* handle) { - if (!IsLocalTarget(attrs)) { - return parent_->Instantiate(function_name, attrs, handle); - } - const string key = Canonicalize(function_name, attrs); - *handle = parent_->GetHandle(key); - if (*handle != kInvalidHandle) { - return Status::OK(); + { + mutex_lock l(mu_); + *handle = gtl::FindWithDefault(table_, key, kInvalidHandle); + if (*handle != kInvalidHandle) { + return Status::OK(); + } } Status s; @@ -449,11 +431,12 @@ Status FunctionLibraryRuntimeImpl::Instantiate(const string& function_name, { mutex_lock l(mu_); - *handle = parent_->GetHandle(key); + *handle = gtl::FindWithDefault(table_, key, kInvalidHandle); if (*handle != kInvalidHandle) { delete fbody; } else { - *handle = parent_->AddHandle(key, device_name_, func_graphs_.size()); + *handle = func_graphs_.size(); + table_.insert({key, *handle}); func_graphs_.push_back(fbody); items_.resize(func_graphs_.size()); } @@ -511,14 +494,13 @@ Status FunctionLibraryRuntimeImpl::CreateItem(Handle handle, Item** item) { } Status FunctionLibraryRuntimeImpl::GetOrCreateItem(Handle handle, Item** item) { - LocalHandle local_handle = parent_->GetHandleOnDevice(device_name_, handle); { mutex_lock l(mu_); - if (local_handle >= items_.size()) { + if (handle >= items_.size()) { return errors::NotFound("Function handle ", handle, " is not valid. Likely an internal error."); } - *item = items_[local_handle]; + *item = items_[handle]; if (*item != nullptr) { (*item)->Ref(); return Status::OK(); @@ -530,9 +512,9 @@ Status FunctionLibraryRuntimeImpl::GetOrCreateItem(Handle handle, Item** item) { { mutex_lock l(mu_); - if (items_[local_handle] == nullptr) { + if (items_[handle] == nullptr) { // Install *item in items_. - items_[local_handle] = *item; + items_[handle] = *item; (*item)->Ref(); } } @@ -546,9 +528,6 @@ void FunctionLibraryRuntimeImpl::Run(const Options& opts, Handle handle, if (opts.cancellation_manager && opts.cancellation_manager->IsCancelled()) { return done(errors::Cancelled("")); } - if (!parent_->IsInstantiatedOnDevice(device_name_, handle)) { - return parent_->Run(opts, handle, args, rets, done); - } const FunctionBody* fbody = GetFunctionBody(handle); FunctionCallFrame* frame = new FunctionCallFrame(fbody->arg_types, fbody->ret_types); @@ -637,21 +616,19 @@ std::unique_ptr<FunctionLibraryRuntime> NewFunctionLibraryRuntime( const DeviceMgr* device_mgr, Env* env, Device* device, int graph_def_version, const FunctionLibraryDefinition* lib_def, const OptimizerOptions& optimizer_options, - CustomKernelCreator custom_kernel_creator, - ProcessFunctionLibraryRuntime* parent) { + CustomKernelCreator custom_kernel_creator) { return std::unique_ptr<FunctionLibraryRuntime>(new FunctionLibraryRuntimeImpl( device_mgr, env, device, graph_def_version, lib_def, optimizer_options, - std::move(custom_kernel_creator), parent)); + std::move(custom_kernel_creator))); } std::unique_ptr<FunctionLibraryRuntime> NewFunctionLibraryRuntime( const DeviceMgr* device_mgr, Env* env, Device* device, int graph_def_version, const FunctionLibraryDefinition* lib_def, - const OptimizerOptions& optimizer_options, - ProcessFunctionLibraryRuntime* parent) { + const OptimizerOptions& optimizer_options) { return NewFunctionLibraryRuntime(device_mgr, env, device, graph_def_version, lib_def, optimizer_options, - GetCustomCreatorSingleton()->Get(), parent); + GetCustomCreatorSingleton()->Get()); } bool RemoveDeadNodes(Graph* g) { diff --git a/tensorflow/core/common_runtime/function.h b/tensorflow/core/common_runtime/function.h index 477340d87a..167f095597 100644 --- a/tensorflow/core/common_runtime/function.h +++ b/tensorflow/core/common_runtime/function.h @@ -21,7 +21,6 @@ limitations under the License. #include "tensorflow/core/common_runtime/device.h" #include "tensorflow/core/common_runtime/device_mgr.h" -#include "tensorflow/core/common_runtime/process_function_library_runtime.h" #include "tensorflow/core/framework/function.h" #include "tensorflow/core/graph/graph.h" #include "tensorflow/core/protobuf/config.pb.h" @@ -37,6 +36,9 @@ static constexpr const char* const kNoInlineAttr = "_noinline"; // takes ownership of the returned OpKernel. // // TODO(zhifengc/phawkins): b/32379046 +typedef std::function<Status(FunctionLibraryRuntime*, const NodeDef&, + std::unique_ptr<OpKernel>*)> + CustomKernelCreator; void RegisterDefaultCustomKernelCreator(CustomKernelCreator cb); // Creates a FunctionLibraryRuntime, which instantiates functions @@ -48,16 +50,11 @@ void RegisterDefaultCustomKernelCreator(CustomKernelCreator cb); // The returned object does not take ownerships of "device" or // "lib_def". The caller must ensure "device" and "lib_def" outlives // the returned object. -// -// The "parent" is a pointer to the ProcessFunctionLibraryRuntime object that -// typically owns the created FunctionLibraryRuntime object. The parent pointer -// is not owned by the FunctionLibraryRuntime object. std::unique_ptr<FunctionLibraryRuntime> NewFunctionLibraryRuntime( const DeviceMgr* device_mgr, Env* env, Device* device, int graph_def_version, const FunctionLibraryDefinition* lib_def, const OptimizerOptions& optimizer_options, - CustomKernelCreator custom_kernel_creator, - ProcessFunctionLibraryRuntime* parent); + CustomKernelCreator custom_kernel_creator); // Same as above except that the returned runtime consults with the // global default custom kernel creator registered by @@ -65,8 +62,7 @@ std::unique_ptr<FunctionLibraryRuntime> NewFunctionLibraryRuntime( std::unique_ptr<FunctionLibraryRuntime> NewFunctionLibraryRuntime( const DeviceMgr* device_mgr, Env* env, Device* device, int graph_def_version, const FunctionLibraryDefinition* lib_def, - const OptimizerOptions& optimizer_options, - ProcessFunctionLibraryRuntime* parent); + const OptimizerOptions& optimizer_options); // FunctionLibraryRuntime::GetFunctionBody returns a description of an // instantiated function that is represented as a Graph with arg/ret diff --git a/tensorflow/core/common_runtime/function_test.cc b/tensorflow/core/common_runtime/function_test.cc index a9f06c4df0..3ca4457b00 100644 --- a/tensorflow/core/common_runtime/function_test.cc +++ b/tensorflow/core/common_runtime/function_test.cc @@ -16,7 +16,6 @@ limitations under the License. #include "tensorflow/core/common_runtime/function.h" #include <atomic> -#include <utility> #include "tensorflow/cc/ops/array_ops_internal.h" #include "tensorflow/cc/ops/function_ops.h" @@ -25,7 +24,6 @@ limitations under the License. #include "tensorflow/core/common_runtime/device.h" #include "tensorflow/core/common_runtime/device_factory.h" #include "tensorflow/core/common_runtime/executor.h" -#include "tensorflow/core/common_runtime/function_testlib.h" #include "tensorflow/core/framework/function.h" #include "tensorflow/core/framework/function_testlib.h" #include "tensorflow/core/framework/op.h" @@ -36,6 +34,7 @@ limitations under the License. #include "tensorflow/core/lib/core/notification.h" #include "tensorflow/core/lib/core/status.h" #include "tensorflow/core/lib/core/status_test_util.h" +#include "tensorflow/core/lib/core/threadpool.h" #include "tensorflow/core/platform/test.h" #include "tensorflow/core/public/session_options.h" #include "tensorflow/core/public/version.h" @@ -50,18 +49,40 @@ Status GetOpSig(const string& op, const OpDef** sig) { return OpRegistry::Global()->LookUpOpDef(op, sig); } +void FunctionTestSchedClosure(std::function<void()> fn) { + static thread::ThreadPool* w = + new thread::ThreadPool(Env::Default(), "Test", 8); + w->Schedule(std::move(fn)); +} + void HasError(const Status& s, const string& substr) { EXPECT_TRUE(StringPiece(s.ToString()).contains(substr)) << s << ", expected substring " << substr; } +// A helper class to make AttrSlice from initializer lists +class Attrs { + public: + Attrs(const std::initializer_list< // NOLINT(runtime/explicit) + std::pair<string, FunctionDefHelper::AttrValueWrapper>>& attrs) { + for (const auto& aval : attrs) { + map_.insert({aval.first, aval.second.proto}); + } + } + + operator AttrSlice() { return AttrSlice(&map_); } // NOLINT(runtime/explicit) + + private: + AttrValueMap map_; +}; + class FunctionTest : public ::testing::Test { protected: FunctionTest() : device_(DeviceFactory::NewDevice("CPU", {}, "/job:localhost/replica:0/task:0")) {} - void Create(const FunctionDef& fdef, test::function::Attrs attrs) { + void Create(const FunctionDef& fdef, Attrs attrs) { exec_ = nullptr; InstantiationResult result; TF_CHECK_OK(InstantiateFunction(fdef, attrs, GetOpSig, &result)); @@ -96,7 +117,7 @@ class FunctionTest : public ::testing::Test { TF_CHECK_OK(frame.SetArgs(args)); Executor::Args exec_args; exec_args.call_frame = &frame; - exec_args.runner = test::function::FunctionTestSchedClosure; + exec_args.runner = FunctionTestSchedClosure; TF_CHECK_OK(exec_->Run(exec_args)); std::vector<Tensor> computed; TF_CHECK_OK(frame.GetRetvals(&computed)); @@ -133,42 +154,41 @@ TEST_F(FunctionTest, WXPlusB) { class FunctionLibraryRuntimeTest : public ::testing::Test { protected: - void Init(const std::vector<FunctionDef>& flib) { - SessionOptions options; - auto* device_count = options.config.mutable_device_count(); - device_count->insert({"CPU", 3}); - TF_CHECK_OK(DeviceFactory::AddDevices( - options, "/job:localhost/replica:0/task:0", &devices_)); + FunctionLibraryRuntimeTest() + : device_(DeviceFactory::NewDevice("CPU", {}, + "/job:localhost/replica:0/task:0")) {} + void Init(const std::vector<FunctionDef>& flib) { FunctionDefLibrary proto; for (const auto& fdef : flib) *(proto.add_function()) = fdef; lib_def_.reset(new FunctionLibraryDefinition(OpRegistry::Global(), proto)); OptimizerOptions opts; - device_mgr_.reset(new DeviceMgr(devices_)); - pflr_.reset(new ProcessFunctionLibraryRuntime( - device_mgr_.get(), Env::Default(), TF_GRAPH_DEF_VERSION, lib_def_.get(), - opts)); - flr0_ = pflr_->GetFLR("/job:localhost/replica:0/task:0/cpu:0"); - flr1_ = pflr_->GetFLR("/job:localhost/replica:0/task:0/cpu:1"); - flr2_ = pflr_->GetFLR("/job:localhost/replica:0/task:0/cpu:2"); + lib_ = + NewFunctionLibraryRuntime(nullptr, Env::Default(), device_.get(), + TF_GRAPH_DEF_VERSION, lib_def_.get(), opts); fdef_lib_ = lib_def_->ToProto(); } - Status Run(FunctionLibraryRuntime* flr, FunctionLibraryRuntime::Handle handle, - const std::vector<Tensor>& args, std::vector<Tensor*> rets) { + Status Run(const string& name, Attrs attrs, const std::vector<Tensor>& args, + std::vector<Tensor*> rets) { + FunctionLibraryRuntime::Handle handle; + Status status = lib_->Instantiate(name, attrs, &handle); + if (!status.ok()) { + return status; + } + std::atomic<int32> call_count(0); std::function<void(std::function<void()>)> runner = [&call_count](std::function<void()> fn) { ++call_count; - test::function::FunctionTestSchedClosure(fn); + FunctionTestSchedClosure(fn); }; Notification done; FunctionLibraryRuntime::Options opts; opts.runner = &runner; std::vector<Tensor> out; - Status status; - flr->Run(opts, handle, args, &out, [&status, &done](const Status& s) { + lib_->Run(opts, handle, args, &out, [&status, &done](const Status& s) { status = s; done.Notify(); }); @@ -186,54 +206,28 @@ class FunctionLibraryRuntimeTest : public ::testing::Test { return Status::OK(); } - Status Instantiate(FunctionLibraryRuntime* flr, const string& name, - test::function::Attrs attrs, - FunctionLibraryRuntime::Handle* handle) { - Status status = flr->Instantiate(name, attrs, handle); - if (!status.ok()) { - return status; - } - return Status::OK(); - } - - Status InstantiateAndRun(FunctionLibraryRuntime* flr, const string& name, - test::function::Attrs attrs, - const std::vector<Tensor>& args, - std::vector<Tensor*> rets) { - FunctionLibraryRuntime::Handle handle; - Status status = flr->Instantiate(name, attrs, &handle); - if (!status.ok()) { - return status; - } - return Run(flr, handle, args, std::move(rets)); - } - - std::unique_ptr<Graph> GetFuncBody(FunctionLibraryRuntime* flr, - const string& name, - test::function::Attrs attrs) { + std::unique_ptr<Graph> GetFuncBody(const string& name, Attrs attrs) { FunctionLibraryRuntime::Handle handle; - Status status = flr->Instantiate(name, attrs, &handle); + Status status = lib_->Instantiate(name, attrs, &handle); if (!status.ok()) { LOG(ERROR) << status; return nullptr; } - const FunctionBody* fbody = flr->GetFunctionBody(handle); + const FunctionBody* fbody = lib_->GetFunctionBody(handle); CHECK_NOTNULL(fbody); std::unique_ptr<Graph> ret(new Graph(lib_def_.get())); CopyGraph(*fbody->graph, ret.get()); return ret; } - std::unique_ptr<Graph> GetGradBody(FunctionLibraryRuntime* flr, - const string& func, - test::function::Attrs attrs) { + std::unique_ptr<Graph> GetGradBody(const string& func, Attrs attrs) { FunctionLibraryRuntime::Handle handle; - Status status = flr->Instantiate(func, attrs, &handle); + Status status = lib_->Instantiate(func, attrs, &handle); if (!status.ok()) { LOG(ERROR) << status; return nullptr; } - const FunctionBody* fbody = flr->GetFunctionBody(handle); + const FunctionBody* fbody = lib_->GetFunctionBody(handle); CHECK_NOTNULL(fbody); std::unique_ptr<FunctionBody> gbody(SymbolicGradient(*fbody)); CHECK_NOTNULL(gbody); @@ -242,29 +236,24 @@ class FunctionLibraryRuntimeTest : public ::testing::Test { return ret; } - FunctionLibraryRuntime* flr0_; - FunctionLibraryRuntime* flr1_; - FunctionLibraryRuntime* flr2_; - std::vector<Device*> devices_; - std::unique_ptr<DeviceMgr> device_mgr_; + std::unique_ptr<Device> device_; std::unique_ptr<FunctionLibraryDefinition> lib_def_; - std::unique_ptr<ProcessFunctionLibraryRuntime> pflr_; + std::unique_ptr<FunctionLibraryRuntime> lib_; FunctionDefLibrary fdef_lib_; }; TEST_F(FunctionLibraryRuntimeTest, IsStateful) { Init({}); - EXPECT_TRUE(flr0_->IsStateful("Variable")); - EXPECT_TRUE(flr0_->IsStateful("VariableV2")); - EXPECT_FALSE(flr0_->IsStateful("Matmul")); + EXPECT_TRUE(lib_->IsStateful("Variable")); + EXPECT_TRUE(lib_->IsStateful("VariableV2")); + EXPECT_FALSE(lib_->IsStateful("Matmul")); } TEST_F(FunctionLibraryRuntimeTest, XTimesTwo) { Init({test::function::XTimesTwo()}); auto x = test::AsTensor<float>({1, 2, 3, 4}); Tensor y; - TF_CHECK_OK( - InstantiateAndRun(flr0_, "XTimesTwo", {{"T", DT_FLOAT}}, {x}, {&y})); + TF_CHECK_OK(Run("XTimesTwo", {{"T", DT_FLOAT}}, {x}, {&y})); test::ExpectTensorEqual<float>(y, test::AsTensor<float>({2, 4, 6, 8})); } @@ -273,14 +262,11 @@ TEST_F(FunctionLibraryRuntimeTest, XTimesN) { test::function::XTimes16()}); auto x = test::AsTensor<float>({1, 2, 3, 4}); Tensor y; - TF_CHECK_OK( - InstantiateAndRun(flr0_, "XTimesTwo", {{"T", DT_FLOAT}}, {x}, {&y})); + TF_CHECK_OK(Run("XTimesTwo", {{"T", DT_FLOAT}}, {x}, {&y})); test::ExpectTensorEqual<float>(y, test::AsTensor<float>({2, 4, 6, 8})); - TF_CHECK_OK( - InstantiateAndRun(flr0_, "XTimesFour", {{"T", DT_FLOAT}}, {x}, {&y})); + TF_CHECK_OK(Run("XTimesFour", {{"T", DT_FLOAT}}, {x}, {&y})); test::ExpectTensorEqual<float>(y, test::AsTensor<float>({4, 8, 12, 16})); - TF_CHECK_OK( - InstantiateAndRun(flr0_, "XTimes16", {{"T", DT_FLOAT}}, {x}, {&y})); + TF_CHECK_OK(Run("XTimes16", {{"T", DT_FLOAT}}, {x}, {&y})); test::ExpectTensorEqual<float>(y, test::AsTensor<float>({16, 32, 48, 64})); } @@ -308,7 +294,7 @@ Output Call(Scope* scope, const string& op_name, const string& fn_name, TEST_F(FunctionLibraryRuntimeTest, ExpandInlineFunctions) { Init({test::function::XTimesTwo(), test::function::XTimesFour(), test::function::XTimes16()}); - std::unique_ptr<Graph> g = GetFuncBody(flr0_, "XTimes16", {{"T", DT_FLOAT}}); + std::unique_ptr<Graph> g = GetFuncBody("XTimes16", {{"T", DT_FLOAT}}); ASSERT_TRUE(g != nullptr); { @@ -326,7 +312,7 @@ TEST_F(FunctionLibraryRuntimeTest, ExpandInlineFunctions) { TF_EXPECT_GRAPH_EQ(expected, actual); } - ExpandInlineFunctions(flr0_, g.get()); + ExpandInlineFunctions(lib_.get(), g.get()); { Scope s = Scope::NewRootScope(); TF_ASSERT_OK(s.graph()->AddFunctionLibrary(fdef_lib_)); @@ -348,7 +334,7 @@ TEST_F(FunctionLibraryRuntimeTest, ExpandInlineFunctions) { TF_EXPECT_GRAPH_EQ(expected, actual); } - ExpandInlineFunctions(flr0_, g.get()); + ExpandInlineFunctions(lib_.get(), g.get()); GraphDef e2; { Scope s = Scope::NewRootScope(); @@ -387,7 +373,7 @@ TEST_F(FunctionLibraryRuntimeTest, ExpandInlineFunctions) { } // No further inlining. - ExpandInlineFunctions(flr0_, g.get()); + ExpandInlineFunctions(lib_.get(), g.get()); { GraphDef actual; g->ToGraphDef(&actual); @@ -439,7 +425,7 @@ TEST_F(FunctionLibraryRuntimeTest, ExpandInlineFunctionsWithControlDeps) { TF_ASSERT_OK(s.ToGraph(g.get())); } - ExpandInlineFunctions(flr0_, g.get()); + ExpandInlineFunctions(lib_.get(), g.get()); { Scope s = Scope::NewRootScope(); TF_ASSERT_OK(s.graph()->AddFunctionLibrary(fdef_lib_)); @@ -463,7 +449,7 @@ TEST_F(FunctionLibraryRuntimeTest, ExpandInlineFunctionsWithControlDeps) { TF_EXPECT_GRAPH_EQ(expected, actual); } - ExpandInlineFunctions(flr0_, g.get()); + ExpandInlineFunctions(lib_.get(), g.get()); { Scope s = Scope::NewRootScope(); TF_ASSERT_OK(s.graph()->AddFunctionLibrary(fdef_lib_)); @@ -509,10 +495,10 @@ TEST_F(FunctionLibraryRuntimeTest, ExpandInlineFunctionsWithControlDeps) { TEST_F(FunctionLibraryRuntimeTest, OptimizeGraph) { Init({test::function::XTimesTwo(), test::function::XTimesFour(), test::function::XTimes16()}); - std::unique_ptr<Graph> g = GetFuncBody(flr0_, "XTimes16", {{"T", DT_FLOAT}}); + std::unique_ptr<Graph> g = GetFuncBody("XTimes16", {{"T", DT_FLOAT}}); ASSERT_TRUE(g != nullptr); - ExpandInlineFunctions(flr0_, g.get()); - OptimizeGraph(flr0_, &g); + ExpandInlineFunctions(lib_.get(), g.get()); + OptimizeGraph(lib_.get(), &g); { Scope s = Scope::NewRootScope(); auto x = ops::_Arg(s.WithOpName("x"), DT_FLOAT, 0); @@ -555,9 +541,9 @@ TEST_F(FunctionLibraryRuntimeTest, ManySwapsNodeDef) { // Return {{"o", "g:output"}}); Init({test::function::Swap(), func}); - std::unique_ptr<Graph> g = GetFuncBody(flr0_, "ManySwapsNodeDef", {}); + std::unique_ptr<Graph> g = GetFuncBody("ManySwapsNodeDef", {}); ASSERT_TRUE(g != nullptr); - OptimizeGraph(flr0_, &g); + OptimizeGraph(lib_.get(), &g); const char* e0 = R"P( (n3:float, n2:float) -> (n3:float) { } @@ -588,9 +574,9 @@ TEST_F(FunctionLibraryRuntimeTest, ControlDeps) { {{"o"}, "Add", {"x2:z:0", "y2:z:0"}, {{"T", DT_FLOAT}}}}, {{"o", "o:z:0"}}); Init({test::function::Swap(), func}); - std::unique_ptr<Graph> g = GetFuncBody(flr0_, "ManySwapsFirst", {}); + std::unique_ptr<Graph> g = GetFuncBody("ManySwapsFirst", {}); ASSERT_TRUE(g != nullptr); - OptimizeGraph(flr0_, &g); + OptimizeGraph(lib_.get(), &g); // NOTE: We can remove func0, func1, func2, func9 with a control edge n8->n5. // But we don't have a pass doing that. @@ -623,7 +609,7 @@ TEST_F(FunctionLibraryRuntimeTest, Error_NotFound) { Init({test::function::XTimesTwo(), test::function::XTimesFour()}); auto x = test::AsTensor<float>({1, 2, 3, 4}); Tensor y; - HasError(InstantiateAndRun(flr0_, "Foo", {{"T", DT_FLOAT}}, {x}, {&y}), + HasError(Run("Foo", {{"T", DT_FLOAT}}, {x}, {&y}), "Not found: Function Foo is not defined."); } @@ -646,27 +632,25 @@ TEST_F(FunctionLibraryRuntimeTest, Error_InstantiaionError) { // Instantiating "XTimesTwo" should fail. FunctionLibraryRuntime::Handle handle; - HasError(flr0_->Instantiate( - "XTimesTwo", test::function::Attrs({{"T", DT_FLOAT}}), &handle), + HasError(lib_->Instantiate("XTimesTwo", Attrs({{"T", DT_FLOAT}}), &handle), "Not found: type attr not found"); // But XTimesFour and XTimes16 instantiation should succeed. Only // when they run, they fail because XTimesTwo is bad. - TF_CHECK_OK(flr0_->Instantiate( - "XTimesFour", test::function::Attrs({{"T", DT_FLOAT}}), &handle)); - TF_CHECK_OK(flr0_->Instantiate( - "XTimes16", test::function::Attrs({{"T", DT_FLOAT}}), &handle)); + TF_CHECK_OK( + lib_->Instantiate("XTimesFour", Attrs({{"T", DT_FLOAT}}), &handle)); + TF_CHECK_OK(lib_->Instantiate("XTimes16", Attrs({{"T", DT_FLOAT}}), &handle)); auto x = test::AsTensor<float>({1, 2, 3, 4}); Tensor y; - HasError(InstantiateAndRun(flr0_, "XTimes16", {{"T", DT_FLOAT}}, {x}, {&y}), + HasError(Run("XTimes16", {{"T", DT_FLOAT}}, {x}, {&y}), "type attr not found"); } TEST_F(FunctionLibraryRuntimeTest, Gradient_XTimesTwo) { Init({test::function::XTimesTwo(), test::function::XTimesFour(), test::function::XTimes16()}); - std::unique_ptr<Graph> f = GetFuncBody(flr0_, "XTimesTwo", {{"T", DT_FLOAT}}); + std::unique_ptr<Graph> f = GetFuncBody("XTimesTwo", {{"T", DT_FLOAT}}); { Scope s = Scope::NewRootScope(); auto x = ops::_Arg(s.WithOpName("x"), DT_FLOAT, 0); @@ -682,7 +666,7 @@ TEST_F(FunctionLibraryRuntimeTest, Gradient_XTimesTwo) { TF_EXPECT_GRAPH_EQ(expected, actual); } - std::unique_ptr<Graph> g = GetGradBody(flr0_, "XTimesTwo", {{"T", DT_FLOAT}}); + std::unique_ptr<Graph> g = GetGradBody("XTimesTwo", {{"T", DT_FLOAT}}); { Scope s = Scope::NewRootScope(); @@ -706,7 +690,7 @@ TEST_F(FunctionLibraryRuntimeTest, Gradient_XTimesTwo) { TF_EXPECT_GRAPH_EQ(expected, actual); } - OptimizeGraph(flr0_, &g); + OptimizeGraph(lib_.get(), &g); { Scope s = Scope::NewRootScope(); @@ -742,7 +726,7 @@ TEST_F(FunctionLibraryRuntimeTest, Gradient_Add) { Init({}); auto T = DT_FLOAT; std::unique_ptr<Graph> g = GetFuncBody( - flr0_, "SymbolicGradient", {{"f", FDH::FunctionRef("Add", {{"T", T}})}}); + "SymbolicGradient", {{"f", FDH::FunctionRef("Add", {{"T", T}})}}); { Scope s = Scope::NewRootScope(); auto x = ops::_Arg(s.WithOpName("x"), DT_FLOAT, 0); @@ -772,7 +756,7 @@ TEST_F(FunctionLibraryRuntimeTest, Gradient_Mul) { Init({}); auto T = DT_FLOAT; std::unique_ptr<Graph> g = GetFuncBody( - flr0_, "SymbolicGradient", {{"f", FDH::FunctionRef("Mul", {{"T", T}})}}); + "SymbolicGradient", {{"f", FDH::FunctionRef("Mul", {{"T", T}})}}); { Scope s = Scope::NewRootScope(); auto x = ops::_Arg(s.WithOpName("x"), DT_FLOAT, 0); @@ -828,7 +812,7 @@ TEST_F(FunctionLibraryRuntimeTest, Gradient_AddSum) { Init({test, grad}); - std::unique_ptr<Graph> g = GetFuncBody(flr0_, "TestGrad", {}); + std::unique_ptr<Graph> g = GetFuncBody("TestGrad", {}); ASSERT_TRUE(g != nullptr); { Scope s = Scope::NewRootScope(); @@ -852,7 +836,7 @@ TEST_F(FunctionLibraryRuntimeTest, Gradient_AddSum) { TF_EXPECT_GRAPH_EQ(expected, actual); } - ExpandInlineFunctions(flr0_, g.get()); + ExpandInlineFunctions(lib_.get(), g.get()); { Scope s = Scope::NewRootScope(); auto x = ops::_Arg(s.WithOpName("x"), DT_FLOAT, 0); @@ -904,7 +888,7 @@ TEST_F(FunctionLibraryRuntimeTest, Gradient_AddSum) { TF_EXPECT_GRAPH_EQ(expected, actual); } - OptimizeGraph(flr0_, &g); + OptimizeGraph(lib_.get(), &g); { Scope s = Scope::NewRootScope(); auto x = ops::_Arg(s.WithOpName("x"), DT_FLOAT, 0); @@ -955,25 +939,6 @@ TEST_F(FunctionLibraryRuntimeTest, Gradient_AddSum) { } } -TEST_F(FunctionLibraryRuntimeTest, CrossDevice) { - Init({test::function::FindDevice()}); - FunctionLibraryRuntime::Handle handle; - TF_CHECK_OK(Instantiate( - flr0_, "FindDevice", - {{"_target", "/job:localhost/replica:0/task:0/cpu:1"}}, &handle)); - - Tensor y; - // Run on flr1_, flr2_ and make sure that the device it ran on was cpu:1. - TF_CHECK_OK(Run(flr1_, handle, {}, {&y})); - test::ExpectTensorEqual<string>( - y, test::AsTensor<string>({"/job:localhost/replica:0/task:0/cpu:1"}, - TensorShape({}))); - TF_CHECK_OK(Run(flr2_, handle, {}, {&y})); - test::ExpectTensorEqual<string>( - y, test::AsTensor<string>({"/job:localhost/replica:0/task:0/cpu:1"}, - TensorShape({}))); -} - namespace { bool DoNothing(Graph* g) { return false; } diff --git a/tensorflow/core/common_runtime/function_testlib.cc b/tensorflow/core/common_runtime/function_testlib.cc deleted file mode 100644 index 64e59762a2..0000000000 --- a/tensorflow/core/common_runtime/function_testlib.cc +++ /dev/null @@ -1,58 +0,0 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ -#include "tensorflow/core/common_runtime/function_testlib.h" - -#include "tensorflow/core/common_runtime/device.h" -#include "tensorflow/core/framework/op_kernel.h" - -namespace tensorflow { -namespace test { -namespace function { - -typedef FunctionDefHelper FDH; - -class FindDeviceOpKernel : public OpKernel { - public: - explicit FindDeviceOpKernel(OpKernelConstruction* ctx) : OpKernel(ctx) {} - void Compute(OpKernelContext* ctx) override { - Tensor* device_tensor = nullptr; - OP_REQUIRES_OK(ctx, ctx->allocate_output("device_name", TensorShape{}, - &device_tensor)); - device_tensor->scalar<string>()() = - ctx->function_library()->device()->name(); - } -}; - -REGISTER_KERNEL_BUILDER(Name("FindDeviceOp").Device(tensorflow::DEVICE_CPU), - FindDeviceOpKernel); -REGISTER_OP("FindDeviceOp").Output("device_name: string"); - -FunctionDef FindDevice() { - return FDH::Define( - // Name - "FindDevice", - // Args - {}, - // Return values - {"device_name: string"}, - // Attr def - {}, - // Nodes - {{{"device_name"}, "FindDeviceOp", {}, {}}}); -} - -} // namespace function -} // namespace test -} // namespace tensorflow diff --git a/tensorflow/core/common_runtime/function_testlib.h b/tensorflow/core/common_runtime/function_testlib.h deleted file mode 100644 index 6b93b188b7..0000000000 --- a/tensorflow/core/common_runtime/function_testlib.h +++ /dev/null @@ -1,31 +0,0 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ -#ifndef THIRD_PARTY_TENSORFLOW_CORE_COMMON_RUNTIME_FUNCTION_TESTLIB_H_ -#define THIRD_PARTY_TENSORFLOW_CORE_COMMON_RUNTIME_FUNCTION_TESTLIB_H_ - -#include "tensorflow/core/framework/function.h" - -namespace tensorflow { -namespace test { -namespace function { - -// {} -> y:DT_STRING (device where this op runs). -FunctionDef FindDevice(); - -} // namespace function -} // namespace test -} // namespace tensorflow - -#endif // THIRD_PARTY_TENSORFLOW_CORE_COMMON_RUNTIME_FUNCTION_TESTLIB_H_ diff --git a/tensorflow/core/common_runtime/process_function_library_runtime.cc b/tensorflow/core/common_runtime/process_function_library_runtime.cc index 0caec03625..97d891fa16 100644 --- a/tensorflow/core/common_runtime/process_function_library_runtime.cc +++ b/tensorflow/core/common_runtime/process_function_library_runtime.cc @@ -14,56 +14,19 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/core/common_runtime/process_function_library_runtime.h" -#include <utility> - #include "tensorflow/core/common_runtime/function.h" -#include "tensorflow/core/lib/gtl/map_util.h" namespace tensorflow { -const char ProcessFunctionLibraryRuntime::kDefaultFLRDevice[] = "null"; - ProcessFunctionLibraryRuntime::ProcessFunctionLibraryRuntime( const DeviceMgr* device_mgr, Env* env, int graph_def_version, const FunctionLibraryDefinition* lib_def, const OptimizerOptions& optimizer_options) { - if (device_mgr == nullptr) { - flr_map_[kDefaultFLRDevice] = - NewFunctionLibraryRuntime(nullptr, env, nullptr, graph_def_version, - lib_def, optimizer_options, this); - return; - } - for (Device* d : device_mgr->ListDevices()) { - flr_map_[d->name()] = - NewFunctionLibraryRuntime(device_mgr, env, d, graph_def_version, - lib_def, optimizer_options, this); - } -} - -ProcessFunctionLibraryRuntime::ProcessFunctionLibraryRuntime( - const DeviceMgr* device_mgr, Env* env, int graph_def_version, - const FunctionLibraryDefinition* lib_def, - const OptimizerOptions& optimizer_options, - CustomKernelCreator custom_kernel_creator) { - if (device_mgr == nullptr) { - flr_map_[kDefaultFLRDevice] = NewFunctionLibraryRuntime( - nullptr, env, nullptr, graph_def_version, lib_def, optimizer_options, - custom_kernel_creator, this); - } + if (!device_mgr) return; for (Device* d : device_mgr->ListDevices()) { flr_map_[d->name()] = NewFunctionLibraryRuntime( - device_mgr, env, d, graph_def_version, lib_def, optimizer_options, - custom_kernel_creator, this); - } -} - -string ProcessFunctionLibraryRuntime::ObtainFunctionTarget( - const AttrSlice& attrs) { - const AttrValue* value; - if (!attrs.Find("_target", &value).ok()) { - return ""; + device_mgr, env, d, graph_def_version, lib_def, optimizer_options); } - return value->s(); } FunctionLibraryRuntime* ProcessFunctionLibraryRuntime::GetFLR( @@ -75,70 +38,4 @@ FunctionLibraryRuntime* ProcessFunctionLibraryRuntime::GetFLR( return flr_map_[device_name].get(); } -FunctionLibraryRuntime::Handle ProcessFunctionLibraryRuntime::AddHandle( - const string& function_key, const string& device_name, - FunctionLibraryRuntime::LocalHandle local_handle) { - mutex_lock l(mu_); - FunctionLibraryRuntime::Handle h = - gtl::FindWithDefault(table_, function_key, kInvalidHandle); - if (h != kInvalidHandle) { - return h; - } - h = function_data_.size(); - function_data_.emplace_back(device_name, local_handle); - table_[function_key] = h; - return h; -} - -FunctionLibraryRuntime::Handle ProcessFunctionLibraryRuntime::GetHandle( - const string& function_key) const { - mutex_lock l(mu_); - return gtl::FindWithDefault(table_, function_key, kInvalidHandle); -} - -bool ProcessFunctionLibraryRuntime::IsInstantiatedOnDevice( - const string& device_name, FunctionLibraryRuntime::Handle handle) { - return GetHandleOnDevice(device_name, handle) != -1; -} - -FunctionLibraryRuntime::LocalHandle -ProcessFunctionLibraryRuntime::GetHandleOnDevice( - const string& device_name, FunctionLibraryRuntime::Handle handle) { - mutex_lock l(mu_); - std::pair<string, FunctionLibraryRuntime::LocalHandle> p = - function_data_[handle]; - if (p.first != device_name) { - return kInvalidLocalHandle; - } - return p.second; -} - -Status ProcessFunctionLibraryRuntime::Instantiate( - const string& function_name, AttrSlice attrs, - FunctionLibraryRuntime::Handle* handle) { - string target = ObtainFunctionTarget(attrs); - - FunctionLibraryRuntime* flr = GetFLR(target); - if (flr != nullptr) { - return flr->Instantiate(function_name, attrs, handle); - } - return errors::InvalidArgument("Target: ", target, " is not supported"); -} - -void ProcessFunctionLibraryRuntime::Run( - const FunctionLibraryRuntime::Options& opts, - FunctionLibraryRuntime::Handle handle, gtl::ArraySlice<Tensor> args, - std::vector<Tensor>* rets, FunctionLibraryRuntime::DoneCallback done) { - FunctionLibraryRuntime* flr = nullptr; - { - mutex_lock l(mu_); - std::pair<string, FunctionLibraryRuntime::LocalHandle> p = - function_data_[handle]; - flr = GetFLR(p.first); - } - if (flr != nullptr) { - return flr->Run(opts, handle, args, rets, std::move(done)); - } -} - } // namespace tensorflow diff --git a/tensorflow/core/common_runtime/process_function_library_runtime.h b/tensorflow/core/common_runtime/process_function_library_runtime.h index 2259997005..53b2223b28 100644 --- a/tensorflow/core/common_runtime/process_function_library_runtime.h +++ b/tensorflow/core/common_runtime/process_function_library_runtime.h @@ -24,6 +24,7 @@ limitations under the License. namespace tensorflow { // A class that stores all the FunctionLibraryRuntime objects, one per device. +// This class is not thread safe. class ProcessFunctionLibraryRuntime { public: // Creates FunctionLibraryRuntime objects for each device in the provided @@ -34,64 +35,10 @@ class ProcessFunctionLibraryRuntime { const FunctionLibraryDefinition* lib_def, const OptimizerOptions& optimizer_options); - ProcessFunctionLibraryRuntime(const DeviceMgr* device_mgr, Env* env, - int graph_def_version, - const FunctionLibraryDefinition* lib_def, - const OptimizerOptions& optimizer_options, - CustomKernelCreator custom_kernel_creator); - - // Given a list of attrs on a function, extracts the "_target" attribute which - // indicates which device to run the function on. If it can't find the _target - // attribute, returns "". Canonicalizes the device name. - static string ObtainFunctionTarget(const AttrSlice& attrs); - - static const char kDefaultFLRDevice[]; // Returns the FunctionLibraryRuntime for the corresponding device_name. FunctionLibraryRuntime* GetFLR(const string& device_name); - // For a given canonicalized key signature of the function instantiated - // on device `device_name` and a `local_handle`, creates a handle and returns - // that value. Use core/common_runtime/framework/function.h::Canonicalize - // to canonicalize the function signature. - FunctionLibraryRuntime::Handle AddHandle( - const string& function_key, const string& device_name, - FunctionLibraryRuntime::LocalHandle local_handle); - - // Returns a handle if found for the given key, else returns kInvalidHandle. - FunctionLibraryRuntime::Handle GetHandle(const string& function_key) const; - - // For the given handle instantiated on device `device_name` returns the local - // index of instantiation of that function. If the function was not - // instantiated on `device_name` returns kInvalidLocalHandle. - FunctionLibraryRuntime::LocalHandle GetHandleOnDevice( - const string& device_name, FunctionLibraryRuntime::Handle handle); - - // Returns true if function with handle `handle` was instantiated on device - // `device_name`. - bool IsInstantiatedOnDevice(const string& device_name, - FunctionLibraryRuntime::Handle handle); - - // Instantiates the function. See framework/function.h for more details. - // Allows for function_name to be instantiated on different devices - // as specified in attrs. - Status Instantiate(const string& function_name, AttrSlice attrs, - FunctionLibraryRuntime::Handle* handle); - - // Runs the function with given `handle`. Function could have been - // instantiated on any device. More details in framework/function.h - void Run(const FunctionLibraryRuntime::Options& opts, - FunctionLibraryRuntime::Handle handle, gtl::ArraySlice<Tensor> args, - std::vector<Tensor>* rets, - FunctionLibraryRuntime::DoneCallback done); - private: - mutable mutex mu_; - - // Holds all the function invocations here. - std::unordered_map<string, FunctionLibraryRuntime::Handle> table_ - GUARDED_BY(mu_); - std::vector<std::pair<string, FunctionLibraryRuntime::LocalHandle>> - function_data_ GUARDED_BY(mu_); std::unordered_map<string, std::unique_ptr<FunctionLibraryRuntime>> flr_map_; }; diff --git a/tensorflow/core/common_runtime/process_function_library_runtime_test.cc b/tensorflow/core/common_runtime/process_function_library_runtime_test.cc index 1536aedde5..d9a5cab88b 100644 --- a/tensorflow/core/common_runtime/process_function_library_runtime_test.cc +++ b/tensorflow/core/common_runtime/process_function_library_runtime_test.cc @@ -17,9 +17,6 @@ limitations under the License. #include <vector> #include "tensorflow/core/common_runtime/device_factory.h" -#include "tensorflow/core/common_runtime/function_testlib.h" -#include "tensorflow/core/framework/function_testlib.h" -#include "tensorflow/core/framework/tensor_testutil.h" #include "tensorflow/core/platform/test.h" #include "tensorflow/core/public/session_options.h" #include "tensorflow/core/public/version.h" @@ -28,8 +25,8 @@ namespace tensorflow { namespace { class ProcessFunctionLibraryRuntimeTest : public ::testing::Test { - protected: - void Init(const std::vector<FunctionDef>& flib) { + public: + ProcessFunctionLibraryRuntimeTest() { SessionOptions options; auto* device_count = options.config.mutable_device_count(); device_count->insert({"CPU", 2}); @@ -37,7 +34,6 @@ class ProcessFunctionLibraryRuntimeTest : public ::testing::Test { &devices_)); device_mgr_.reset(new DeviceMgr(devices_)); FunctionDefLibrary proto; - for (const auto& fdef : flib) *(proto.add_function()) = fdef; lib_def_.reset(new FunctionLibraryDefinition(OpRegistry::Global(), proto)); OptimizerOptions opts; proc_flr_.reset(new ProcessFunctionLibraryRuntime( @@ -45,43 +41,7 @@ class ProcessFunctionLibraryRuntimeTest : public ::testing::Test { opts)); } - Status Run(const string& name, test::function::Attrs attrs, - const std::vector<Tensor>& args, std::vector<Tensor*> rets) { - FunctionLibraryRuntime::Handle handle; - Status status = proc_flr_->Instantiate(name, attrs, &handle); - if (!status.ok()) { - return status; - } - - std::atomic<int32> call_count(0); - std::function<void(std::function<void()>)> runner = - [&call_count](std::function<void()> fn) { - ++call_count; - test::function::FunctionTestSchedClosure(fn); - }; - - Notification done; - FunctionLibraryRuntime::Options opts; - opts.runner = &runner; - std::vector<Tensor> out; - proc_flr_->Run(opts, handle, args, &out, [&status, &done](const Status& s) { - status = s; - done.Notify(); - }); - done.WaitForNotification(); - if (!status.ok()) { - return status; - } - CHECK_EQ(rets.size(), out.size()); - for (size_t i = 0; i < rets.size(); ++i) { - *rets[i] = out[i]; - } - - EXPECT_GE(call_count, 1); // Test runner is used. - - return Status::OK(); - } - + protected: std::vector<Device*> devices_; std::unique_ptr<DeviceMgr> device_mgr_; std::unique_ptr<FunctionLibraryDefinition> lib_def_; @@ -89,7 +49,6 @@ class ProcessFunctionLibraryRuntimeTest : public ::testing::Test { }; TEST_F(ProcessFunctionLibraryRuntimeTest, Basic) { - Init({}); FunctionLibraryRuntime* flr = proc_flr_->GetFLR("/job:a/replica:0/task:0/cpu:0"); EXPECT_NE(flr, nullptr); @@ -101,87 +60,5 @@ TEST_F(ProcessFunctionLibraryRuntimeTest, Basic) { EXPECT_EQ(flr, nullptr); } -TEST_F(ProcessFunctionLibraryRuntimeTest, ObtainFunctionTarget) { - AttrSlice empty_attrs; - string target = - ProcessFunctionLibraryRuntime::ObtainFunctionTarget(empty_attrs); - EXPECT_EQ("", target); - - AttrValueMap attr_values; - AttrValue v; - v.set_s("/job:a/replica:0/task:0/cpu:1"); - AddAttr("_target", v, &attr_values); - AttrSlice attrs(&attr_values); - target = ProcessFunctionLibraryRuntime::ObtainFunctionTarget(attrs); - EXPECT_EQ("/job:a/replica:0/task:0/cpu:1", target); -} - -TEST_F(ProcessFunctionLibraryRuntimeTest, SingleCall) { - Init({test::function::XTimesTwo()}); - auto x = test::AsTensor<float>({1, 2, 3, 4}); - Tensor y; - TF_CHECK_OK( - Run("XTimesTwo", - {{"T", DT_FLOAT}, {"_target", "/job:a/replica:0/task:0/cpu:0"}}, {x}, - {&y})); - test::ExpectTensorEqual<float>(y, test::AsTensor<float>({2, 4, 6, 8})); -} - -TEST_F(ProcessFunctionLibraryRuntimeTest, SingleCallFindDevice) { - Init({test::function::FindDevice()}); - Tensor y; - TF_CHECK_OK(Run("FindDevice", {{"_target", "/job:a/replica:0/task:0/cpu:0"}}, - {}, {&y})); - test::ExpectTensorEqual<string>( - y, test::AsTensor<string>({"/job:a/replica:0/task:0/cpu:0"}, - TensorShape({}))); -} - -TEST_F(ProcessFunctionLibraryRuntimeTest, MultipleCallsSameDeviceXTimes) { - Init({test::function::XTimesTwo(), test::function::XTimesFour()}); - auto x = test::AsTensor<float>({1, 2, 3, 4}); - Tensor y; - TF_CHECK_OK( - Run("XTimesTwo", - {{"T", DT_FLOAT}, {"_target", "/job:a/replica:0/task:0/cpu:0"}}, {x}, - {&y})); - test::ExpectTensorEqual<float>(y, test::AsTensor<float>({2, 4, 6, 8})); - TF_CHECK_OK( - Run("XTimesFour", - {{"T", DT_FLOAT}, {"_target", "/job:a/replica:0/task:0/cpu:0"}}, {x}, - {&y})); - test::ExpectTensorEqual<float>(y, test::AsTensor<float>({4, 8, 12, 16})); -} - -TEST_F(ProcessFunctionLibraryRuntimeTest, MultipleCallsSameDeviceFindDevice) { - Init({test::function::FindDevice()}); - Tensor y; - TF_CHECK_OK(Run("FindDevice", {{"_target", "/job:a/replica:0/task:0/cpu:1"}}, - {}, {&y})); - test::ExpectTensorEqual<string>( - y, test::AsTensor<string>({"/job:a/replica:0/task:0/cpu:1"}, - TensorShape({}))); - TF_CHECK_OK(Run("FindDevice", {{"_target", "/job:a/replica:0/task:0/cpu:1"}}, - {}, {&y})); - test::ExpectTensorEqual<string>( - y, test::AsTensor<string>({"/job:a/replica:0/task:0/cpu:1"}, - TensorShape({}))); -} - -TEST_F(ProcessFunctionLibraryRuntimeTest, MultipleCallsDiffDeviceFindDevice) { - Init({test::function::FindDevice()}); - Tensor y; - TF_CHECK_OK(Run("FindDevice", {{"_target", "/job:a/replica:0/task:0/cpu:0"}}, - {}, {&y})); - test::ExpectTensorEqual<string>( - y, test::AsTensor<string>({"/job:a/replica:0/task:0/cpu:0"}, - TensorShape({}))); - TF_CHECK_OK(Run("FindDevice", {{"_target", "/job:a/replica:0/task:0/cpu:1"}}, - {}, {&y})); - test::ExpectTensorEqual<string>( - y, test::AsTensor<string>({"/job:a/replica:0/task:0/cpu:1"}, - TensorShape({}))); -} - } // anonymous namespace } // namespace tensorflow diff --git a/tensorflow/core/framework/function.h b/tensorflow/core/framework/function.h index 717f0c8575..045976dd06 100644 --- a/tensorflow/core/framework/function.h +++ b/tensorflow/core/framework/function.h @@ -437,16 +437,8 @@ class FunctionLibraryRuntime { // Returns the graph version number. virtual int graph_def_version() = 0; - - typedef uint64 LocalHandle; }; -const FunctionLibraryRuntime::Handle kInvalidHandle = -1; -const FunctionLibraryRuntime::LocalHandle kInvalidLocalHandle = -1; -typedef std::function<Status(FunctionLibraryRuntime*, const NodeDef&, - std::unique_ptr<OpKernel>*)> - CustomKernelCreator; - // To register a gradient function for a builtin op, one should use // REGISTER_OP_GRADIENT(<op_name>, <c++ grad factory>); // diff --git a/tensorflow/core/framework/function_testlib.cc b/tensorflow/core/framework/function_testlib.cc index e6ef8425fb..4ee23226da 100644 --- a/tensorflow/core/framework/function_testlib.cc +++ b/tensorflow/core/framework/function_testlib.cc @@ -19,7 +19,6 @@ limitations under the License. #include "tensorflow/core/framework/node_def.pb.h" #include "tensorflow/core/framework/tensor_testutil.h" #include "tensorflow/core/framework/versions.pb.h" -#include "tensorflow/core/lib/core/threadpool.h" #include "tensorflow/core/public/version.h" namespace tensorflow { @@ -173,12 +172,6 @@ FunctionDef Swap() { {{"o1"}, "Identity", {"i0"}, {{"T", "$T"}}}}); } -void FunctionTestSchedClosure(std::function<void()> fn) { - static thread::ThreadPool* w = - new thread::ThreadPool(Env::Default(), "Test", 8); - w->Schedule(std::move(fn)); -} - } // end namespace function } // end namespace test } // end namespace tensorflow diff --git a/tensorflow/core/framework/function_testlib.h b/tensorflow/core/framework/function_testlib.h index a742fe0ce7..49e5b0c99d 100644 --- a/tensorflow/core/framework/function_testlib.h +++ b/tensorflow/core/framework/function_testlib.h @@ -30,22 +30,6 @@ namespace tensorflow { namespace test { namespace function { -// A helper class to make AttrSlice from initializer lists -class Attrs { - public: - Attrs(const std::initializer_list< // NOLINT(runtime/explicit) - std::pair<string, FunctionDefHelper::AttrValueWrapper>>& attrs) { - for (const auto& aval : attrs) { - map_.insert({aval.first, aval.second.proto}); - } - } - - operator AttrSlice() { return AttrSlice(&map_); } // NOLINT(runtime/explicit) - - private: - AttrValueMap map_; -}; - // Helper to construct a NodeDef. NodeDef NDef( const string& name, const string& op, gtl::ArraySlice<string> inputs, @@ -78,8 +62,6 @@ FunctionDef NonZero(); // x:T, y:T -> y:T, x:T FunctionDef Swap(); -void FunctionTestSchedClosure(std::function<void()> fn); - } // end namespace function } // end namespace test } // end namespace tensorflow diff --git a/tensorflow/core/grappler/grappler_item_builder.cc b/tensorflow/core/grappler/grappler_item_builder.cc index b740e8a999..6136651410 100644 --- a/tensorflow/core/grappler/grappler_item_builder.cc +++ b/tensorflow/core/grappler/grappler_item_builder.cc @@ -104,11 +104,9 @@ Status OptimizeGraph(const GraphDef& graph_def, GraphDef* output_graph_def, optimizer_opts->set_do_function_inlining(cfg.inline_functions); // Create the function library runtime. - std::unique_ptr<ProcessFunctionLibraryRuntime> pflr( - new ProcessFunctionLibraryRuntime(dvc_mgr.get(), env, - inlined_graph_def.versions().producer(), - &function_library, *optimizer_opts)); - FunctionLibraryRuntime* flr = pflr->GetFLR(devices[0]->name()); + std::unique_ptr<FunctionLibraryRuntime> flib(NewFunctionLibraryRuntime( + dvc_mgr.get(), env, devices[0], inlined_graph_def.versions().producer(), + &function_library, *optimizer_opts)); // Create the GraphOptimizer to optimize the graph def. GraphConstructorOptions graph_ctor_opts; @@ -124,7 +122,8 @@ Status OptimizeGraph(const GraphDef& graph_def, GraphDef* output_graph_def, // Optimize the graph. GraphOptimizer optimizer(*optimizer_opts); - optimizer.Optimize(flr, env, devices[0], &graphptr, /*shape_map=*/nullptr); + optimizer.Optimize(flib.get(), env, devices[0], &graphptr, + /*shape_map=*/nullptr); graphptr->ToGraphDef(output_graph_def); return Status::OK(); diff --git a/tensorflow/core/kernels/captured_function.cc b/tensorflow/core/kernels/captured_function.cc index 15e9680f26..eb52de6d85 100644 --- a/tensorflow/core/kernels/captured_function.cc +++ b/tensorflow/core/kernels/captured_function.cc @@ -40,9 +40,9 @@ Status CapturedFunction::Create( // NOTE(mrry): We need to assign a name to the device, and we choose // the same name as the calling context's device so that we do not // need to rewrite resource handles that are found in `captured_inputs`. - Device* device = - new ThreadPoolDevice(SessionOptions(), ctx->device()->attributes().name(), - Bytes(256 << 20), DeviceLocality(), cpu_allocator()); + std::unique_ptr<Device> device(new ThreadPoolDevice( + SessionOptions(), ctx->device()->attributes().name(), Bytes(256 << 20), + DeviceLocality(), cpu_allocator())); // TODO(mrry): Handle arbitrary resource types, which might require a // redesign (or opening up access to `ResourceMgr::DoLookup()` and @@ -82,24 +82,20 @@ Status CapturedFunction::Create( } #undef HANDLE_RESOURCE_TYPE - std::unique_ptr<DeviceMgr> device_mgr(new DeviceMgr({device})); std::unique_ptr<FunctionLibraryDefinition> flib_def( new FunctionLibraryDefinition( *ctx->function_library()->GetFunctionLibraryDefinition())); - std::unique_ptr<ProcessFunctionLibraryRuntime> pflr( - new ProcessFunctionLibraryRuntime( - device_mgr.get(), ctx->env(), graph_def_version, flib_def.get(), - {} /* TODO(mrry): OptimizerOptions? */)); - - FunctionLibraryRuntime* lib = pflr->GetFLR(device->name()); + std::unique_ptr<FunctionLibraryRuntime> lib(NewFunctionLibraryRuntime( + nullptr /* device_mgr */, ctx->env(), device.get(), graph_def_version, + flib_def.get(), {} /* TODO(mrry): OptimizerOptions? */)); FunctionLibraryRuntime::Handle f_handle; TF_RETURN_IF_ERROR( lib->Instantiate(func->name(), AttrSlice(&func->attr()), &f_handle)); out_function->reset(new CapturedFunction( - device, std::move(device_mgr), std::move(flib_def), std::move(pflr), lib, - f_handle, std::move(captured_inputs))); + std::move(device), std::move(flib_def), std::move(lib), f_handle, + std::move(captured_inputs))); return Status::OK(); } @@ -140,16 +136,14 @@ Status CapturedFunction::Run(FunctionLibraryRuntime::Options f_opts, } CapturedFunction::CapturedFunction( - Device* device, std::unique_ptr<DeviceMgr> device_mgr, + std::unique_ptr<Device> device, std::unique_ptr<FunctionLibraryDefinition> flib_def, - std::unique_ptr<ProcessFunctionLibraryRuntime> pflr, - FunctionLibraryRuntime* lib, FunctionLibraryRuntime::Handle f_handle, + std::unique_ptr<FunctionLibraryRuntime> lib, + FunctionLibraryRuntime::Handle f_handle, std::vector<Tensor> captured_inputs) - : device_(device), - device_mgr_(std::move(device_mgr)), + : device_(std::move(device)), flib_def_(std::move(flib_def)), - pflr_(std::move(pflr)), - lib_(lib), + lib_(std::move(lib)), f_handle_(f_handle), captured_inputs_(std::move(captured_inputs)) {} diff --git a/tensorflow/core/kernels/captured_function.h b/tensorflow/core/kernels/captured_function.h index 03679736f3..e24bcb9d82 100644 --- a/tensorflow/core/kernels/captured_function.h +++ b/tensorflow/core/kernels/captured_function.h @@ -63,23 +63,20 @@ class CapturedFunction { gtl::ArraySlice<Tensor> args, std::vector<Tensor>* rets, const string& prefix); - const Device* device() const { return device_; } + Device* device() const { return device_.get(); } ResourceMgr* resource_manager() const { return device_->resource_manager(); } private: - CapturedFunction(Device* device, std::unique_ptr<DeviceMgr> device_mgr, + CapturedFunction(std::unique_ptr<Device> device, std::unique_ptr<FunctionLibraryDefinition> flib_def, - std::unique_ptr<ProcessFunctionLibraryRuntime> pflr, - FunctionLibraryRuntime* lib, + std::unique_ptr<FunctionLibraryRuntime> lib, FunctionLibraryRuntime::Handle f_handle, std::vector<Tensor> captured_inputs); - Device* const device_; // owned by device_mgr_. - const std::unique_ptr<DeviceMgr> device_mgr_; + const std::unique_ptr<Device> device_; const std::unique_ptr<FunctionLibraryDefinition> flib_def_; - const std::unique_ptr<ProcessFunctionLibraryRuntime> pflr_; - FunctionLibraryRuntime* const lib_; // owned by pflr_. + const std::unique_ptr<FunctionLibraryRuntime> lib_; const FunctionLibraryRuntime::Handle f_handle_; const std::vector<Tensor> captured_inputs_; diff --git a/tensorflow/core/kernels/function_ops.cc b/tensorflow/core/kernels/function_ops.cc index a1dfd4c3d3..b831b5bff5 100644 --- a/tensorflow/core/kernels/function_ops.cc +++ b/tensorflow/core/kernels/function_ops.cc @@ -278,66 +278,4 @@ REGISTER_KERNEL_BUILDER(Name(kGradientOp).Device(DEVICE_SYCL), SymbolicGradientOp); #endif // TENSORFLOW_USE_SYCL - -class RemoteCallOp : public AsyncOpKernel { - public: - explicit RemoteCallOp(OpKernelConstruction* ctx) : AsyncOpKernel(ctx) { - OP_REQUIRES_OK(ctx, ctx->GetAttr("f", &func_)); - } - - ~RemoteCallOp() override {} - - void ComputeAsync(OpKernelContext* ctx, DoneCallback done) override { - const Tensor* target; - OP_REQUIRES_OK_ASYNC(ctx, ctx->input("target", &target), done); - AttrValueMap attr_values = func_->attr(); - AttrValue v; - v.set_s(target->scalar<string>()()); - AddAttr("_target", v, &attr_values); - - FunctionLibraryRuntime* lib = ctx->function_library(); - OP_REQUIRES_ASYNC(ctx, lib != nullptr, - errors::Internal("No function library is provided."), - done); - FunctionLibraryRuntime::Handle handle; - OP_REQUIRES_OK_ASYNC( - ctx, lib->Instantiate(func_->name(), AttrSlice(&attr_values), &handle), - done); - - OpInputList arguments; - OP_REQUIRES_OK_ASYNC(ctx, ctx->input_list("args", &arguments), done); - - FunctionLibraryRuntime::Options opts; - opts.step_id = ctx->step_id(); - opts.runner = ctx->runner(); - std::vector<Tensor> args; - args.reserve(arguments.size()); - for (const Tensor& argument : arguments) { - args.push_back(argument); - } - auto* rets = new std::vector<Tensor>; - lib->Run(opts, handle, args, rets, [rets, done, ctx](const Status& status) { - if (!status.ok()) { - ctx->SetStatus(status); - } - for (size_t i = 0; i < rets->size(); ++i) { - ctx->set_output(i, (*rets)[i]); - } - delete rets; - done(); - }); - } - - private: - string target_; - const NameAttrList* func_; - TF_DISALLOW_COPY_AND_ASSIGN(RemoteCallOp); -}; - -REGISTER_KERNEL_BUILDER(Name("RemoteCall").Device(DEVICE_CPU), RemoteCallOp); -REGISTER_KERNEL_BUILDER(Name("RemoteCall").Device(DEVICE_GPU), RemoteCallOp); -#if TENSORFLOW_USE_SYCL -REGISTER_KERNEL_BUILDER(Name("RemoteCall").Device(DEVICE_SYCL), RemoteCallOp); - -#endif // TENSORFLOW_USE_SYCL } // namespace tensorflow diff --git a/tensorflow/core/ops/compat/ops_history.v1.pbtxt b/tensorflow/core/ops/compat/ops_history.v1.pbtxt index 6ff1a3fc03..e76573ffdb 100644 --- a/tensorflow/core/ops/compat/ops_history.v1.pbtxt +++ b/tensorflow/core/ops/compat/ops_history.v1.pbtxt @@ -19997,37 +19997,6 @@ op { } } op { - name: "RemoteCall" - input_arg { - name: "target" - type: DT_STRING - } - input_arg { - name: "args" - type_list_attr: "Tin" - } - output_arg { - name: "output" - type_list_attr: "Tout" - } - attr { - name: "Tin" - type: "list(type)" - has_minimum: true - minimum: 1 - } - attr { - name: "Tout" - type: "list(type)" - has_minimum: true - minimum: 1 - } - attr { - name: "f" - type: "func" - } -} -op { name: "RemoteFusedGraphExecute" input_arg { name: "inputs" diff --git a/tensorflow/core/ops/functional_ops.cc b/tensorflow/core/ops/functional_ops.cc index 5fd21ec88f..d1f9e94942 100644 --- a/tensorflow/core/ops/functional_ops.cc +++ b/tensorflow/core/ops/functional_ops.cc @@ -13,7 +13,6 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/core/framework/common_shape_fns.h" #include "tensorflow/core/framework/op.h" #include "tensorflow/core/framework/shape_inference.h" @@ -66,22 +65,4 @@ to x_i. (Needs some math expert to say the comment above better.) )doc"); -REGISTER_OP("RemoteCall") - .Input("target: string") - .Input("args: Tin") - .Output("output: Tout") - .Attr("Tin: list(type)") - .Attr("Tout: list(type)") - .Attr("f: func") - .SetShapeFn(shape_inference::UnknownShape) - .Doc(R"doc( -Runs function `f` on a remote device indicated by `target`. - -target: A fully specified device name where we want to run the function. -args: A list of arguments for the function. -output: A list of return values. -Tin: The type list for the arguments. -Tout: The type list for the return values. -f: The function to run remotely. -)doc"); } // end namespace tensorflow diff --git a/tensorflow/core/ops/ops.pbtxt b/tensorflow/core/ops/ops.pbtxt index 87cdc30fb1..06eabdcdcd 100644 --- a/tensorflow/core/ops/ops.pbtxt +++ b/tensorflow/core/ops/ops.pbtxt @@ -19608,44 +19608,6 @@ op { summary: "Computes rectified linear gradients for a Relu operation." } op { - name: "RemoteCall" - input_arg { - name: "target" - description: "A fully specified device name where we want to run the function." - type: DT_STRING - } - input_arg { - name: "args" - description: "A list of arguments for the function." - type_list_attr: "Tin" - } - output_arg { - name: "output" - description: "A list of return values." - type_list_attr: "Tout" - } - attr { - name: "Tin" - type: "list(type)" - description: "The type list for the arguments." - has_minimum: true - minimum: 1 - } - attr { - name: "Tout" - type: "list(type)" - description: "The type list for the return values." - has_minimum: true - minimum: 1 - } - attr { - name: "f" - type: "func" - description: "The function to run remotely." - } - summary: "Runs function `f` on a remote device indicated by `target`." -} -op { name: "RemoteFusedGraphExecute" input_arg { name: "inputs" |