author     A. Unique TensorFlower <gardener@tensorflow.org>   2018-09-28 14:08:25 -0700
committer  TensorFlower Gardener <gardener@tensorflow.org>    2018-09-28 14:34:43 -0700
commit     f83da5b0aa37ba55c1b2eaa093e6d043b73f5982 (patch)
tree       d28c727251f910dc0c7b7a6184286919d436e88f /tensorflow
parent     1724d155f00b49bc817189247cbfb0df2092a9da (diff)
Introduce the abstraction of RunHandler, which each DirectSession obtains from a
RunHandlerPool and uses for the duration of a single RunInternal() call. It is used
for running inter-op closures with a global scheduler, which will in the future be
used to improve both median and tail latency (for use cases like CPU inference).
When global pools aren't used, this change should be a no-op.
PiperOrigin-RevId: 214992852
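
For orientation, the expected usage pattern documented in the new run_handler.h is
roughly the following sketch (the thread count and the RunStep() wrapper are
illustrative only; DirectSession creates the pool lazily via
GetOrCreateRunHandlerPool()):

#include <memory>
#include "tensorflow/core/framework/run_handler.h"

// Create one pool for the whole process, sized to the inter-op thread count.
static tensorflow::RunHandlerPool* run_handler_pool =
    new tensorflow::RunHandlerPool(/*num_inter_op_threads=*/16);

void RunStep() {
  // Get() blocks until an inactive handler is free; the handler becomes
  // inactive again when the returned unique_ptr is destroyed.
  std::unique_ptr<tensorflow::RunHandler> handler = run_handler_pool->Get();

  // Schedule all inter-op closures for this step through the handler so the
  // pool can apply its global scheduling policy.
  handler->ScheduleInterOpClosure([] { /* inter-op work */ });
}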
Diffstat (limited to 'tensorflow')
-rw-r--r--  tensorflow/core/BUILD                                                          16
-rw-r--r--  tensorflow/core/common_runtime/direct_session.cc                               49
-rw-r--r--  tensorflow/core/common_runtime/direct_session.h                                 3
-rw-r--r--  tensorflow/core/common_runtime/direct_session_test.cc                          28
-rw-r--r--  tensorflow/core/framework/run_handler.cc                                      249
-rw-r--r--  tensorflow/core/framework/run_handler.h                                        95
-rw-r--r--  tensorflow/core/framework/run_handler_util.cc                                  57
-rw-r--r--  tensorflow/core/framework/run_handler_util.h                                   43
-rw-r--r--  tensorflow/core/framework/run_handler_util_test.cc                             93
-rw-r--r--  tensorflow/core/protobuf/config.proto                                           5
-rw-r--r--  tensorflow/tools/api/golden/v1/tensorflow.-run-options.-experimental.pbtxt      6
-rw-r--r--  tensorflow/tools/api/golden/v1/tensorflow.-run-options.pbtxt                    6
-rw-r--r--  tensorflow/tools/api/golden/v2/tensorflow.-run-options.-experimental.pbtxt      6
-rw-r--r--  tensorflow/tools/api/golden/v2/tensorflow.-run-options.pbtxt                    6
14 files changed, 656 insertions, 6 deletions
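
The new behavior is opt-in per Run() call via RunOptions.experimental.use_run_handler_pool,
and DirectSession only honors it when the session uses neither per-session threads nor
session_inter_op_thread_pool (see ShouldUseRunHandlerPool() below). A minimal client-side
sketch, assuming an already-created session and prepared feeds/fetches (inputs,
output_names, target_nodes):

#include <vector>
#include "tensorflow/core/protobuf/config.pb.h"
#include "tensorflow/core/public/session.h"

// Sketch: opt a single Session::Run() call into the global run-handler pool.
tensorflow::RunOptions run_options;
run_options.mutable_experimental()->set_use_run_handler_pool(true);

std::vector<tensorflow::Tensor> outputs;
tensorflow::Status s =
    session->Run(run_options, inputs, output_names, target_nodes, &outputs,
                 /*run_metadata=*/nullptr);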
diff --git a/tensorflow/core/BUILD b/tensorflow/core/BUILD index 50fe308b73..7da4b9fbd0 100644 --- a/tensorflow/core/BUILD +++ b/tensorflow/core/BUILD @@ -2486,6 +2486,8 @@ FRAMEWORK_INTERNAL_PUBLIC_HEADERS = [ "framework/op_segment.h", "framework/rendezvous.h", # only needed for tests "framework/resource_var.h", + "framework/run_handler.h", + "framework/run_handler_util.h", "framework/tensor_reference.h", "framework/tracking_allocator.h", # only needed for tests "framework/unique_tensor_references.h", @@ -2972,6 +2974,7 @@ tf_cuda_library( ":core_cpu_internal", ":device_tracer", ":framework", + ":framework_internal", ":graph", ":lib", ":lib_internal", @@ -4119,6 +4122,19 @@ tf_cc_test( ], ) +tf_cc_test( + name = "framework_run_handler_util_test", + size = "small", + srcs = ["framework/run_handler_util_test.cc"], + linkstatic = tf_kernel_tests_linkstatic(), + deps = [ + ":framework_internal", + ":lib", + ":test", + ":test_main", + ], +) + tf_cuda_cc_test( name = "common_runtime_direct_session_test", size = "small", diff --git a/tensorflow/core/common_runtime/direct_session.cc b/tensorflow/core/common_runtime/direct_session.cc index 841181f8c3..458e133b68 100644 --- a/tensorflow/core/common_runtime/direct_session.cc +++ b/tensorflow/core/common_runtime/direct_session.cc @@ -40,6 +40,7 @@ limitations under the License. #include "tensorflow/core/framework/graph_def_util.h" #include "tensorflow/core/framework/log_memory.h" #include "tensorflow/core/framework/node_def.pb.h" +#include "tensorflow/core/framework/run_handler.h" #include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/framework/versions.pb.h" #include "tensorflow/core/graph/algorithm.h" @@ -244,6 +245,21 @@ void DirectSession::SchedClosure(thread::ThreadPool* pool, #endif // __ANDROID__ } +static RunHandlerPool* GetOrCreateRunHandlerPool( + const SessionOptions& options) { + static RunHandlerPool* pool = + new RunHandlerPool(NumInterOpThreadsFromSessionOptions(options)); + return pool; +} + +bool DirectSession::ShouldUseRunHandlerPool() const { + if (options_.config.session_inter_op_thread_pool_size() > 0 || + options_.config.use_per_session_threads()) { + return false; + } + return true; +} + DirectSession::DirectSession(const SessionOptions& options, const DeviceMgr* device_mgr, DirectSessionFactory* const factory) @@ -582,16 +598,37 @@ Status DirectSession::RunInternal(int64 step_id, const RunOptions& run_options, } } - Executor::Args::Runner default_runner = [this, - pool](Executor::Args::Closure c) { - SchedClosure(pool, std::move(c)); - }; + std::unique_ptr<RunHandler> handler; + if (ShouldUseRunHandlerPool() && + run_options.experimental().use_run_handler_pool()) { + // Non-null only when a global inter-op pool is used. + VLOG(1) << "Using RunHandler to scheduler inter-op closures."; + handler = GetOrCreateRunHandlerPool(options_)->Get(); + } + auto* handler_ptr = handler.get(); + + Executor::Args::Runner default_runner = nullptr; + + if (pool == nullptr) { + default_runner = [](Executor::Args::Closure c) { c(); }; + } else if (handler_ptr != nullptr) { + default_runner = [handler_ptr](Executor::Args::Closure c) { + handler_ptr->ScheduleInterOpClosure(std::move(c)); + }; + } else { + default_runner = [this, pool](Executor::Args::Closure c) { + SchedClosure(pool, std::move(c)); + }; + } + for (const auto& item : executors_and_keys->items) { - // TODO(zhengxq): support partial run. - // TODO(zhengxq): if the device picks its own threadpool, we need to assign + // TODO(azaks): support partial run. 
+ // TODO(azaks): if the device picks its own threadpool, we need to assign // less threads to the main compute pool by default. thread::ThreadPool* device_thread_pool = item.device->tensorflow_device_thread_pool(); + // TODO(crk): Investigate usage of RunHandlerPool when using device specific + // thread pool(s). if (!device_thread_pool) { args.runner = default_runner; } else { diff --git a/tensorflow/core/common_runtime/direct_session.h b/tensorflow/core/common_runtime/direct_session.h index 4a6a921ea7..3a168bbe3f 100644 --- a/tensorflow/core/common_runtime/direct_session.h +++ b/tensorflow/core/common_runtime/direct_session.h @@ -247,6 +247,9 @@ class DirectSession : public Session { ExecutorsAndKeys* executors_and_keys, RunMetadata* run_metadata); + // Returns whether inter-op execution uses a global pool. + bool ShouldUseRunHandlerPool() const; + ::tensorflow::Status ExtendLocked(const GraphDef& graph) EXCLUSIVE_LOCKS_REQUIRED(graph_state_lock_); diff --git a/tensorflow/core/common_runtime/direct_session_test.cc b/tensorflow/core/common_runtime/direct_session_test.cc index 65e816c202..e3e431f800 100644 --- a/tensorflow/core/common_runtime/direct_session_test.cc +++ b/tensorflow/core/common_runtime/direct_session_test.cc @@ -625,6 +625,34 @@ TEST_F(DirectSessionMinusAXTest, RunSimpleNetworkWithOpts_Callable) { EXPECT_EQ(run_metadata.step_stats().dev_stats_size(), 2); } +TEST_F(DirectSessionMinusAXTest, UseRunHandlerPool) { + Initialize({3, 2, -1, 0}); + auto session = CreateSession(); + ASSERT_TRUE(session != nullptr); + TF_ASSERT_OK(session->Create(def_)); + std::vector<std::pair<string, Tensor>> inputs; + + // Request two targets: one fetch output and one non-fetched output. + std::vector<string> output_names = {y_ + ":0"}; + std::vector<string> target_nodes = {y_neg_}; + std::vector<Tensor> outputs; + + // Prepares RunOptions and RunMetadata + RunOptions run_options; + run_options.mutable_experimental()->set_use_run_handler_pool(true); + + Status s = session->Run(run_options, inputs, output_names, target_nodes, + &outputs, nullptr); + TF_ASSERT_OK(s); + + ASSERT_EQ(1, outputs.size()); + // The first output should be initialized and have the correct + // output. + auto mat = outputs[0].matrix<float>(); + ASSERT_TRUE(outputs[0].IsInitialized()); + EXPECT_FLOAT_EQ(5.0, mat(0, 0)); +} + TEST(DirectSessionTest, KeepsStateAcrossRunsOfSession) { GraphDef def; Graph g(OpRegistry::Global()); diff --git a/tensorflow/core/framework/run_handler.cc b/tensorflow/core/framework/run_handler.cc new file mode 100644 index 0000000000..0c4007eafc --- /dev/null +++ b/tensorflow/core/framework/run_handler.cc @@ -0,0 +1,249 @@ +/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#define EIGEN_USE_THREADS + +#include "tensorflow/core/framework/run_handler.h" + +#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" +#include "tensorflow/core/framework/run_handler_util.h" +#include "tensorflow/core/platform/mutex.h" +#include "tensorflow/core/util/ptr_util.h" + +namespace tensorflow { + +// Contains the concrete implementation of the RunHandler. +// Externally visible RunHandler class simply forwards the work to this one. +class RunHandler::Impl { + public: + explicit Impl(RunHandlerPool::Impl* pool_impl) : pool_impl_(pool_impl) { + Reset(); + } + + ~Impl() {} + + void set_inter_op_scheduling_range(std::uint_fast32_t start, + std::uint_fast32_t limit) { + inter_op_scheduling_range_.store(EncodePartition(start, limit), + std::memory_order_release); + } + + std::uint_fast32_t inter_op_scheduling_range() const { + return inter_op_scheduling_range_.load(std::memory_order_acquire); + } + + // Stores now time (in microseconds) since unix epoch when the handler is + // requested via RunHandlerPool::Get(). + uint64 start_time_us() const { return start_time_us_; } + + void ScheduleInterOpClosure(std::function<void()> fn); + + void Reset(); + + RunHandlerPool::Impl* pool_impl() { return pool_impl_; } + + private: + // Encoding/decoding logic for storing [start, limit) into a single + // uint_fast32_t int. We assume that pool_num_threads < (1 << 16). + const int kMaxPartitionBits = 16; + const int kMaxThreads = 1 << kMaxPartitionBits; + + std::uint_fast32_t EncodePartition(std::uint_fast32_t start, + std::uint_fast32_t limit) { + return (start << kMaxPartitionBits) | limit; + } + + void DecodePartition(std::uint_fast32_t val, std::uint_fast32_t* start, + std::uint_fast32_t* limit) { + *limit = val & (kMaxThreads - 1); + val >>= kMaxPartitionBits; + *start = val; + } + + std::atomic_uint_fast32_t inter_op_scheduling_range_; + RunHandlerPool::Impl* pool_impl_; // NOT OWNED. + uint64 start_time_us_; +}; + +// Contains shared state across all run handlers present in the pool. Also +// responsible for pool management decisions. +// This class is thread safe. +class RunHandlerPool::Impl { + public: + explicit Impl(int num_inter_op_threads) + : max_handlers_(128), + inter_op_thread_pool_(new thread::ThreadPool( + Env::Default(), ThreadOptions(), "inter_op", num_inter_op_threads)), + iterations_(0) { + VLOG(1) << "Creating a RunHandlerPool with max handlers: " << max_handlers_; + for (int i = 0; i < max_handlers_; ++i) { + handlers_.emplace_back(new RunHandler::Impl(this)); + free_handlers_.push_back(handlers_.back().get()); + } + } + + ~Impl() { + // Sanity check that all handlers have been returned back to the pool before + // destruction. + DCHECK_EQ(handlers_.size(), max_handlers_); + DCHECK_EQ(free_handlers_.size(), handlers_.size()); + DCHECK_EQ(sorted_active_handlers_.size(), 0); + } + + thread::ThreadPool* inter_op_thread_pool() const { + return inter_op_thread_pool_.get(); + } + + std::unique_ptr<RunHandler> Get() LOCKS_EXCLUDED(mu_) { + mutex_lock l(mu_); + while (free_handlers_.empty()) { + one_handler_free_.wait(l); + } + // Remove the last entry from free_handlers_ and add to the end of + // sorted_active_handlers_. + auto* handler_impl = free_handlers_.back(); + handler_impl->Reset(); + // Sortedness isn't violated if we simply add at the end of the list, since + // handlers are expected to be obtained in increasing order of time. 
+ sorted_active_handlers_.push_back(handler_impl); + DCHECK_LE(sorted_active_handlers_.size(), max_handlers_); + free_handlers_.pop_back(); + + RecomputePoolStatsLocked(); + return WrapUnique<RunHandler>(new RunHandler(handler_impl)); + } + + void ReleaseHandler(RunHandler::Impl* handler) LOCKS_EXCLUDED(mu_) { + { + mutex_lock l(mu_); + DCHECK_GT(sorted_active_handlers_.size(), 0); + + uint64 now = tensorflow::Env::Default()->NowMicros(); + double elapsed = (now - handler->start_time_us()) / 1000.0; + time_hist_.Add(elapsed); + + // Erase from and update sorted_active_handlers_. Add it to the end of + // free_handlers_. + auto iter = std::find(sorted_active_handlers_.begin(), + sorted_active_handlers_.end(), handler); + DCHECK(iter != sorted_active_handlers_.end()) + << "Unexpected handler: " << handler + << " is being requested for release"; + + // Remove this handler from this list and add it to the list of free + // handlers. + sorted_active_handlers_.erase(iter); + free_handlers_.push_back(handler); + DCHECK_LE(free_handlers_.size(), max_handlers_); + + RecomputePoolStatsLocked(); + } + one_handler_free_.notify_one(); + } + + private: + void RecomputePoolStatsLocked() EXCLUSIVE_LOCKS_REQUIRED(mu_); + + // Maximum number of handlers pre-created during pool construction time. The + // number has been chosen expecting each handler might at least want 1 + // inter-op thread for execution (during compute intensive workloads like + // inference). + const int max_handlers_; + + // Thread safe part. + const std::unique_ptr<thread::ThreadPool> inter_op_thread_pool_; + + // Thread compatible part used only by lock under RunHandlerPool. + // Handlers are sorted by start time. + std::vector<RunHandler::Impl*> sorted_active_handlers_ GUARDED_BY(mu_); + std::vector<RunHandler::Impl*> free_handlers_ GUARDED_BY(mu_); + std::vector<std::unique_ptr<RunHandler::Impl>> handlers_ GUARDED_BY(mu_); + // Histogram of elapsed runtime of every handler (in ms). 
+ histogram::Histogram time_hist_ GUARDED_BY(mu_); + std::vector<std::uint_fast32_t> inter_op_start_ GUARDED_BY(mu_); + std::vector<std::uint_fast32_t> inter_op_limit_ GUARDED_BY(mu_); + int64 iterations_ GUARDED_BY(mu_); + condition_variable one_handler_free_; + mutex mu_; +}; + +void RunHandlerPool::Impl::RecomputePoolStatsLocked() { + int num_active_requests = sorted_active_handlers_.size(); + if (num_active_requests == 0) return; + + int num_threads = inter_op_thread_pool_->NumThreads(); + + inter_op_start_.resize(num_active_requests); + inter_op_limit_.resize(num_active_requests); + + const int kMinThreadsPerRequest = 3; + ComputeInterOpSchedulingRanges(num_active_requests, num_threads, + kMinThreadsPerRequest, &inter_op_start_, + &inter_op_limit_); + + for (int i = 0; i < num_active_requests; ++i) { + sorted_active_handlers_[i]->set_inter_op_scheduling_range( + inter_op_start_[i], inter_op_limit_[i]); + } + + if (iterations_++ % 5000 == 0 && VLOG_IS_ON(1)) { + VLOG(1) << "Printing time histogram: " << time_hist_.ToString(); + VLOG(1) << "Active session runs: " << num_active_requests; + uint64 now = tensorflow::Env::Default()->NowMicros(); + string ranges_str = ""; + string times_str = ""; + for (int i = 0; i < num_active_requests; ++i) { + if (i > 0) { + times_str += " "; + ranges_str += " "; + } + + times_str += strings::StrCat( + (now - sorted_active_handlers_[i]->start_time_us()) / 1000.0, " ms."); + ranges_str += strings::StrCat("[", inter_op_start_[i], ", ", + inter_op_limit_[i], ")"); + } + VLOG(1) << "Elapsed times are: " << times_str; + VLOG(1) << "Ranges are: " << ranges_str; + } +} + +void RunHandler::Impl::ScheduleInterOpClosure(std::function<void()> fn) { + std::uint_fast32_t start = 0, limit = 0; + DecodePartition(inter_op_scheduling_range(), &start, &limit); + pool_impl_->inter_op_thread_pool()->Schedule(std::move(fn)); +} + +void RunHandler::Impl::Reset() { + set_inter_op_scheduling_range( + 0, pool_impl_->inter_op_thread_pool()->NumThreads()); + start_time_us_ = tensorflow::Env::Default()->NowMicros(); +} + +RunHandlerPool::RunHandlerPool(int num_inter_op_threads) + : impl_(new Impl(num_inter_op_threads)) {} + +RunHandlerPool::~RunHandlerPool() {} + +std::unique_ptr<RunHandler> RunHandlerPool::Get() { return impl_->Get(); } + +RunHandler::RunHandler(Impl* impl) : impl_(impl) {} + +void RunHandler::ScheduleInterOpClosure(std::function<void()> fn) { + impl_->ScheduleInterOpClosure(std::move(fn)); +} + +RunHandler::~RunHandler() { impl_->pool_impl()->ReleaseHandler(impl_); } +} // namespace tensorflow diff --git a/tensorflow/core/framework/run_handler.h b/tensorflow/core/framework/run_handler.h new file mode 100644 index 0000000000..72fa6301b4 --- /dev/null +++ b/tensorflow/core/framework/run_handler.h @@ -0,0 +1,95 @@ +/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef TENSORFLOW_CORE_FRAMEWORK_RUN_HANDLER_H_ +#define TENSORFLOW_CORE_FRAMEWORK_RUN_HANDLER_H_ + +#include "tensorflow/core/lib/core/threadpool.h" +#include "tensorflow/core/lib/histogram/histogram.h" +#include "tensorflow/core/platform/mutex.h" +#include "tensorflow/core/platform/thread_annotations.h" +#include "tensorflow/core/protobuf/config.pb.h" + +namespace tensorflow { + +class RunHandler; + +// RunHandlerPool is a fixed size pool of pre-allocated RunHandlers +// that can be used for tracking inter-op work for a given Session::Run(). +// RunHandler(s) in the pool are initially 'inactive'. A RunHandler becomes +// 'active' when its unique_ptr is returned by Get() and is being used by a +// client. It becomes 'inactive' once more when its unique_ptr gets destroyed. +// +// Expected usage: +// +// * Create a single RunHandlerPool (say run_handler_pool_). +// +// * When a Session::Run() is invoked, obtain a handler by: +// auto handler = run_handler_pool_->Get(); +// +// * Use handler for scheduling all inter-op work by: +// handler->ScheduleInterOpClosure(closure); +// +// This class is thread safe. +class RunHandlerPool { + public: + explicit RunHandlerPool(int num_inter_op_threads); + ~RunHandlerPool(); + + // Returns an inactive RunHandler from the pool. + // + // RunHandlers in RunHandlerPool are initially 'inactive'. + // A RunHandler becomes 'active' when its unique_ptr its returned by Get() + // and is being used by a client. It becomes 'inactive' once more when the + // unique_ptr is destroyed. + // + // Will block unless there is an inactive handler. + std::unique_ptr<RunHandler> Get(); + + private: + class Impl; + friend class RunHandler; + + std::unique_ptr<Impl> impl_; +}; + +// RunHandler can be used to schedule inter-op closures to run on a global pool +// shared across all Session::Run(s). +// +// It can only be created via RunHandlerPool::Get(). +// +// This class can be used instead of directly scheduling closures on a global +// pool since it maintains a global view across all sessions and optimizes pool +// scheduling to improve (median and tail) latency. +// +// This class is thread safe. +class RunHandler { + public: + void ScheduleInterOpClosure(std::function<void()> fn); + + ~RunHandler(); + + private: + class Impl; + friend class RunHandlerPool::Impl; + + explicit RunHandler(Impl* impl); + + Impl* impl_; // NOT OWNED. +}; + +} // end namespace tensorflow. + +#endif // TENSORFLOW_CORE_FRAMEWORK_RUN_HANDLER_H_ diff --git a/tensorflow/core/framework/run_handler_util.cc b/tensorflow/core/framework/run_handler_util.cc new file mode 100644 index 0000000000..3087998c69 --- /dev/null +++ b/tensorflow/core/framework/run_handler_util.cc @@ -0,0 +1,57 @@ +/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/core/framework/run_handler_util.h" + +#include <algorithm> +#include <cmath> +#include "tensorflow/core/platform/logging.h" + +namespace tensorflow { + +void ComputeInterOpSchedulingRanges(int num_active_requests, int num_threads, + int min_threads_per_request, + std::vector<std::uint_fast32_t>* start_vec, + std::vector<std::uint_fast32_t>* end_vec) { + // Each request is expected to have weight W[i] = num_active_requests - i. + // Therefore, total_weight = sum of all request weights. + float total_weight = 0.5f * num_active_requests * (num_active_requests + 1); + float demand_factor = static_cast<float>(num_threads) / total_weight; + float last_cumulative_weight = 0.0; + min_threads_per_request = std::max(1, min_threads_per_request); + for (int i = 0; i != num_active_requests; i++) { + float cumulative_weight = + static_cast<float>(i + 1) * + (num_active_requests - static_cast<float>(i) * 0.5f); + float weight = cumulative_weight - last_cumulative_weight; + // Quantize thread_demand by rounding up, and also satisfying + // `min_threads_per_request` constraint. + // Note: We subtract a small epsilon (0.00001) to prevent ceil(..) from + // rounding weights like 4.0 to 5. + int demand = + std::max(min_threads_per_request, + static_cast<int>(ceil(weight * demand_factor - 0.00001f))); + // For the quantized range [start, end); compute the floor of real start, + // and expand downwards from there with length `demand` and adjust for + // boundary conditions. + int start = last_cumulative_weight * demand_factor; + int end = std::min(num_threads, start + demand); + start = std::max(0, std::min(start, end - demand)); + start_vec->at(i) = start; + end_vec->at(i) = end; + last_cumulative_weight = cumulative_weight; + } +} +} // namespace tensorflow diff --git a/tensorflow/core/framework/run_handler_util.h b/tensorflow/core/framework/run_handler_util.h new file mode 100644 index 0000000000..c0c36aeccb --- /dev/null +++ b/tensorflow/core/framework/run_handler_util.h @@ -0,0 +1,43 @@ +/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_FRAMEWORK_RUN_HANDLER_UTIL_H_ +#define TENSORFLOW_CORE_FRAMEWORK_RUN_HANDLER_UTIL_H_ + +#include <cstdint> +#include <vector> + +namespace tensorflow { + +// Assign thread ranges to requests. +// Requests are numbered 0...num_active_requests-1, and +// threads are numbered 0...num_threads-1. +// On return, the range start_vec->at(i)...end_vec->at(i)-1 +// indicates the subrange of the threads available to request i. +// The ranges given to different requests may overlap. +// Lower numbered requests will tend to be assigned more threads. +// Thus, a client might associate older requests with lower +// array indices so they receive access to more threads. 
+// However, the routine ensures that each request is given access +// to at least min(min_threads_per_request, num_threads) threads. +// Every thread will be assigned to at least one request range, +// assuming there is at least one request. +void ComputeInterOpSchedulingRanges(int num_active_requests, int num_threads, + int min_threads_per_request, + std::vector<std::uint_fast32_t>* start_vec, + std::vector<std::uint_fast32_t>* end_vec); + +} // end namespace tensorflow +#endif // TENSORFLOW_CORE_FRAMEWORK_RUN_HANDLER_UTIL_H_ diff --git a/tensorflow/core/framework/run_handler_util_test.cc b/tensorflow/core/framework/run_handler_util_test.cc new file mode 100644 index 0000000000..a1928c132b --- /dev/null +++ b/tensorflow/core/framework/run_handler_util_test.cc @@ -0,0 +1,93 @@ +/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/core/framework/run_handler_util.h" + +#include <vector> +#include "tensorflow/core/lib/strings/strcat.h" +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/platform/test.h" +namespace tensorflow { +namespace { + +void VerifyFunction(int num_active_requests, int num_threads, + int min_threads_per_request, bool print_stats = false) { + if (print_stats) { + LOG(INFO) << "Test case# num_active_requests: " << num_active_requests + << " num_threads: " << num_threads + << " min_threads: " << min_threads_per_request; + } + std::vector<std::uint_fast32_t> start(num_active_requests); + std::vector<std::uint_fast32_t> end(num_active_requests); + + ComputeInterOpSchedulingRanges(num_active_requests, num_threads, + min_threads_per_request, &start, &end); + string range_str = ""; + for (int i = 0; i < num_active_requests; ++i) { + if (i > 0) range_str += " "; + range_str += strings::StrCat("[", start[i], ", ", end[i], ")"); + + ASSERT_GE(start[i], 0) << range_str; + ASSERT_LE(end[i], num_threads) << range_str; + if (i > 0) { + // Due to linearly decreasing demand, #threads(i - 1) >= #threads(i) + ASSERT_GE(end[i - 1] - start[i - 1], end[i] - start[i]) << range_str; + // No missing threads. + ASSERT_GE(end[i - 1], start[i]) << range_str; + } + // Each interval is at least of size 'min_threads_per_request'. + ASSERT_GE((end[i] - start[i]), min_threads_per_request) << range_str; + // Verify that assigned (quantized) threads is not overly estimated + // from real demand, when the demand is high (>= + // min_threads_per_request). + float entry_weight = num_active_requests - i; + float total_weight = 0.5f * num_active_requests * (num_active_requests + 1); + float thread_demand = (entry_weight * num_threads) / total_weight; + if (thread_demand > min_threads_per_request) { + // We expect some over-estimation of threads due to quantization, + // but we hope it's not more than 1 extra thread. 
+ ASSERT_NEAR(end[i] - start[i], thread_demand, 1.0) + << "Ranges: " << range_str << " thread_demand: " << thread_demand + << " i: " << i; + } + } + ASSERT_EQ(end[num_active_requests - 1], num_threads); + ASSERT_EQ(start[0], 0); + if (print_stats) { + LOG(INFO) << "Assigned ranges: " << range_str; + } +} + +TEST(RunHandlerUtilTest, TestComputeInterOpSchedulingRanges) { + const int kMinThreadsPerRequestBound = 12; + const int kMaxActiveRequests = 128; + const int kMaxThreads = 128; + + for (int min_threads_per_request = 1; + min_threads_per_request <= kMinThreadsPerRequestBound; + ++min_threads_per_request) { + for (int num_active_requests = 1; num_active_requests <= kMaxActiveRequests; + ++num_active_requests) { + for (int num_threads = min_threads_per_request; + num_threads <= kMaxThreads; ++num_threads) { + VerifyFunction(num_active_requests, num_threads, + min_threads_per_request); + } + } + } +} + +} // namespace +} // namespace tensorflow diff --git a/tensorflow/core/protobuf/config.proto b/tensorflow/core/protobuf/config.proto index 85cd02350a..104ab039cb 100644 --- a/tensorflow/core/protobuf/config.proto +++ b/tensorflow/core/protobuf/config.proto @@ -453,6 +453,11 @@ message RunOptions { // same group_key value (in a distributed computation where tasks // run disjoint graphs). int64 collective_graph_key = 1; + // If true, then operations (using the inter-op pool) across all + // session::run() calls will be centrally scheduled, optimizing for (median + // and tail) latency. + // Consider using this option for CPU-bound workloads like inference. + bool use_run_handler_pool = 2; }; Experimental experimental = 8; diff --git a/tensorflow/tools/api/golden/v1/tensorflow.-run-options.-experimental.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.-run-options.-experimental.pbtxt index 537e73aa89..47b5b56faf 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.-run-options.-experimental.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.-run-options.-experimental.pbtxt @@ -8,5 +8,11 @@ tf_proto { label: LABEL_OPTIONAL type: TYPE_INT64 } + field { + name: "use_run_handler_pool" + number: 2 + label: LABEL_OPTIONAL + type: TYPE_BOOL + } } } diff --git a/tensorflow/tools/api/golden/v1/tensorflow.-run-options.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.-run-options.pbtxt index cec04a2bf0..c0c2e7b9f8 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.-run-options.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.-run-options.pbtxt @@ -55,6 +55,12 @@ tf_proto { label: LABEL_OPTIONAL type: TYPE_INT64 } + field { + name: "use_run_handler_pool" + number: 2 + label: LABEL_OPTIONAL + type: TYPE_BOOL + } } enum_type { name: "TraceLevel" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.-run-options.-experimental.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.-run-options.-experimental.pbtxt index 537e73aa89..47b5b56faf 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.-run-options.-experimental.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.-run-options.-experimental.pbtxt @@ -8,5 +8,11 @@ tf_proto { label: LABEL_OPTIONAL type: TYPE_INT64 } + field { + name: "use_run_handler_pool" + number: 2 + label: LABEL_OPTIONAL + type: TYPE_BOOL + } } } diff --git a/tensorflow/tools/api/golden/v2/tensorflow.-run-options.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.-run-options.pbtxt index cec04a2bf0..c0c2e7b9f8 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.-run-options.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.-run-options.pbtxt @@ -55,6 +55,12 @@ 
tf_proto { label: LABEL_OPTIONAL type: TYPE_INT64 } + field { + name: "use_run_handler_pool" + number: 2 + label: LABEL_OPTIONAL + type: TYPE_BOOL + } } enum_type { name: "TraceLevel" |
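
For reference, the range computation added in run_handler_util.cc gives request i the
weight W[i] = num_active_requests - i, scales by demand_factor = num_threads / total_weight,
and turns cumulative weights into [start, end) thread ranges so older requests receive
more threads. A small self-contained sketch of calling it (the concrete numbers are
illustrative; note that in this change ScheduleInterOpClosure() decodes the range but
still schedules directly on the shared pool):

#include <cstdint>
#include <iostream>
#include <vector>

#include "tensorflow/core/framework/run_handler_util.h"

// Worked example: 3 active requests, 12 inter-op threads, min 1 thread each.
// Weights are {3, 2, 1}, total_weight = 6, demand_factor = 12 / 6 = 2, which
// yields the ranges [0, 6), [6, 10), [10, 12).
int main() {
  const int num_active_requests = 3;
  const int num_threads = 12;
  std::vector<std::uint_fast32_t> start(num_active_requests);
  std::vector<std::uint_fast32_t> end(num_active_requests);
  tensorflow::ComputeInterOpSchedulingRanges(num_active_requests, num_threads,
                                             /*min_threads_per_request=*/1,
                                             &start, &end);
  for (int i = 0; i < num_active_requests; ++i) {
    std::cout << "request " << i << ": [" << start[i] << ", " << end[i]
              << ")\n";
  }
  return 0;
}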