/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/core/common_runtime/function.h"

#include <atomic>
#include <utility>

#include "tensorflow/cc/ops/array_ops_internal.h"
#include "tensorflow/cc/ops/function_ops.h"
#include "tensorflow/cc/ops/functional_ops.h"
#include "tensorflow/cc/ops/standard_ops.h"
#include "tensorflow/core/common_runtime/device.h"
#include "tensorflow/core/common_runtime/device_factory.h"
#include "tensorflow/core/common_runtime/executor.h"
#include "tensorflow/core/common_runtime/function_testlib.h"
#include "tensorflow/core/common_runtime/rendezvous_mgr.h"
#include "tensorflow/core/common_runtime/step_stats_collector.h"
#include "tensorflow/core/framework/function.h"
#include "tensorflow/core/framework/function_testlib.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/tensor_testutil.h"
#include "tensorflow/core/graph/graph_constructor.h"
#include "tensorflow/core/lib/core/notification.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/lib/core/threadpool.h"
#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/platform/test.h"
#include "tensorflow/core/public/session_options.h"
#include "tensorflow/core/public/version.h"
#include "tensorflow/core/util/equal_graph_def.h"

namespace tensorflow {
namespace {

class FunctionLibraryRuntimeTest : public ::testing::Test {
 protected:
  void Init(const std::vector<FunctionDef>& flib,
            thread::ThreadPool* default_thread_pool = nullptr) {
    SessionOptions options;
    auto* device_count = options.config.mutable_device_count();
    device_count->insert({"CPU", 3});
    TF_CHECK_OK(DeviceFactory::AddDevices(
        options, "/job:localhost/replica:0/task:0", &devices_));

    FunctionDefLibrary proto;
    for (const auto& fdef : flib) *(proto.add_function()) = fdef;
    lib_def_.reset(new FunctionLibraryDefinition(OpRegistry::Global(), proto));
    OptimizerOptions opts;
    device_mgr_.reset(new DeviceMgr(devices_));
    pflr_.reset(new ProcessFunctionLibraryRuntime(
        device_mgr_.get(), Env::Default(), TF_GRAPH_DEF_VERSION,
        lib_def_.get(), opts, default_thread_pool,
        nullptr /* cluster_flr */));
    flr0_ = pflr_->GetFLR("/job:localhost/replica:0/task:0/cpu:0");
    flr1_ = pflr_->GetFLR("/job:localhost/replica:0/task:0/cpu:1");
    flr2_ = pflr_->GetFLR("/job:localhost/replica:0/task:0/cpu:2");
    fdef_lib_ = lib_def_->ToProto();
  }

  Status Run(FunctionLibraryRuntime* flr, FunctionLibraryRuntime::Handle handle,
             FunctionLibraryRuntime::Options opts,
             const std::vector<Tensor>& args, std::vector<Tensor*> rets,
             bool add_runner = true) {
    std::atomic<int32> call_count(0);
    std::function<void(std::function<void()>)> runner =
        [&call_count](std::function<void()> fn) {
          ++call_count;
          test::function::FunctionTestSchedClosure(fn);
        };

    if (add_runner) {
      opts.runner = &runner;
    } else {
      opts.runner = nullptr;
    }
    Notification done;
    std::vector<Tensor> out;
    Status status;
    flr->Run(opts, handle, args, &out, [&status, &done](const Status& s) {
      status = s;
      done.Notify();
    });
    done.WaitForNotification();
    if (!status.ok()) {
      return status;
    }
    CHECK_EQ(rets.size(), out.size());
    for (size_t i = 0; i < rets.size(); ++i) {
      *rets[i] = out[i];
    }

    if (add_runner) {
      EXPECT_GE(call_count, 1);  // Test runner is used.
    }

    return Status::OK();
  }

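  // Instantiates `name` on `flr`; the second overload additionally forwards
  // explicit InstantiateOptions to the runtime.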
  Status Instantiate(FunctionLibraryRuntime* flr, const string& name,
                     test::function::Attrs attrs,
                     FunctionLibraryRuntime::Handle* handle) {
    return flr->Instantiate(name, attrs, handle);
  }

  Status Instantiate(FunctionLibraryRuntime* flr, const string& name,
                     test::function::Attrs attrs,
                     const FunctionLibraryRuntime::InstantiateOptions& options,
                     FunctionLibraryRuntime::Handle* handle) {
    return flr->Instantiate(name, attrs, options, handle);
  }

  Status InstantiateAndRun(FunctionLibraryRuntime* flr, const string& name,
                           test::function::Attrs attrs,
                           const std::vector<Tensor>& args,
                           std::vector<Tensor*> rets, bool add_runner = true) {
    return InstantiateAndRun(flr, name, attrs,
                             FunctionLibraryRuntime::InstantiateOptions(), args,
                             std::move(rets), add_runner);
  }

  Status InstantiateAndRun(
      FunctionLibraryRuntime* flr, const string& name,
      test::function::Attrs attrs,
      const FunctionLibraryRuntime::InstantiateOptions& options,
      const std::vector<Tensor>& args, std::vector<Tensor*> rets,
      bool add_runner = true) {
    FunctionLibraryRuntime::Handle handle;
    Status status = flr->Instantiate(name, attrs, options, &handle);
    if (!status.ok()) {
      return status;
    }
    FunctionLibraryRuntime::Options opts;
    status = Run(flr, handle, opts, args, rets, add_runner);
    if (!status.ok()) return status;

    // Release the handle and try running again. It should not succeed.
    status = flr->ReleaseHandle(handle);
    if (!status.ok()) return status;

    Status status2 = Run(flr, handle, opts, args, std::move(rets));
    EXPECT_TRUE(errors::IsInvalidArgument(status2));
    EXPECT_TRUE(
        str_util::StrContains(status2.error_message(), "remote execution."));

    return status;
  }

  Status Run(FunctionLibraryRuntime* flr, FunctionLibraryRuntime::Handle handle,
             FunctionLibraryRuntime::Options opts, CallFrameInterface* frame,
             bool add_runner = true) {
    std::atomic<int32> call_count(0);
    std::function<void(std::function<void()>)> runner =
        [&call_count](std::function<void()> fn) {
          ++call_count;
          test::function::FunctionTestSchedClosure(fn);
        };

    if (add_runner) {
      opts.runner = &runner;
    } else {
      opts.runner = nullptr;
    }
    Notification done;
    std::vector<Tensor> out;
    Status status;
    flr->Run(opts, handle, frame, [&status, &done](const Status& s) {
      status = s;
      done.Notify();
    });
    done.WaitForNotification();
    if (!status.ok()) {
      return status;
    }

    if (add_runner) {
      EXPECT_GE(call_count, 1);  // Test runner is used.
    }

    return Status::OK();
  }

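  // Devices and per-device runtimes created by Init(); flr0_..flr2_ each
  // correspond to one of the three CPU devices registered above.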
  FunctionLibraryRuntime* flr0_;
  FunctionLibraryRuntime* flr1_;
  FunctionLibraryRuntime* flr2_;
  std::vector<Device*> devices_;
  std::unique_ptr<DeviceMgr> device_mgr_;
  std::unique_ptr<FunctionLibraryDefinition> lib_def_;
  std::unique_ptr<ProcessFunctionLibraryRuntime> pflr_;
  FunctionDefLibrary fdef_lib_;
};

TEST_F(FunctionLibraryRuntimeTest, DefaultThreadpool) {
  using test::function::blocking_op_state;
  using test::function::BlockingOpState;

  thread::ThreadPool* tp = new thread::ThreadPool(Env::Default(), "FLRTest", 1);
  Init({test::function::BlockingOpFn(), test::function::XTimesTwo()}, tp);

  auto x = test::AsScalar<float>(1.3);
  Tensor y;
  blocking_op_state = new BlockingOpState();

  thread::ThreadPool* tp1 = new thread::ThreadPool(Env::Default(), "tp1", 5);
  bool finished_running = false;
  tp1->Schedule([&x, &y, &finished_running, this]() {
    TF_CHECK_OK(InstantiateAndRun(flr0_, "BlockingOpFn", {}, {x}, {&y},
                                  false /* add_runner */));
    finished_running = true;
  });

  // InstantiateAndRun shouldn't finish because BlockingOpFn should be blocked.
  EXPECT_FALSE(finished_running);

  FunctionLibraryRuntime::Handle h;
  TF_CHECK_OK(Instantiate(flr0_, "XTimesTwo", {{"T", DT_FLOAT}}, &h));

  auto x1 = test::AsTensor<float>({1, 2, 3, 4});
  std::atomic<int32> num_done(0);
  FunctionLibraryRuntime::Options opts;
  for (int i = 0; i < 4; ++i) {
    tp1->Schedule([&h, &x1, &opts, &num_done, this]() {
      Tensor y1;
      TF_CHECK_OK(Run(flr0_, h, opts, {x1}, {&y1}, false /* add_runner */));
      num_done.fetch_add(1);
    });
  }
  // All the 4 Run() calls should be blocked because the runner is occupied.
  EXPECT_EQ(0, num_done.load());

  blocking_op_state->AwaitState(1);
  blocking_op_state->MoveToState(1, 2);
  // Now the runner should be unblocked and all the other Run() calls should
  // proceed.
  blocking_op_state->AwaitState(3);
  blocking_op_state->MoveToState(3, 0);
  delete tp1;
  EXPECT_TRUE(finished_running);
  EXPECT_EQ(4, num_done.load());

  delete blocking_op_state;
  blocking_op_state = nullptr;

  delete tp;
}

}  // namespace
}  // namespace tensorflow