aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/framework/run_handler_util_test.cc
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2018-09-27 12:37:05 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-09-27 12:41:43 -0700
commit750466c6e6624d279de7f9a43accd682d487509c (patch)
treea97a88d432433b3c55775f64bb7a5f86a9f034b2 /tensorflow/core/framework/run_handler_util_test.cc
parent561a3c4331ebfaac3e61c524911bf6fe85f4ebc9 (diff)
Introduce the abstraction of RunHandler which each DirectSession can use for
the duration of a single RunInternal() call from RunHandlerPool. We want to leverage this abstraction for improving the cross-session inter-op parallelism for lower latency inference in the future. In the case that global pools aren't used, this change should be a no-op. PiperOrigin-RevId: 214818187
Diffstat (limited to 'tensorflow/core/framework/run_handler_util_test.cc')
-rw-r--r--tensorflow/core/framework/run_handler_util_test.cc93
1 files changed, 93 insertions, 0 deletions
diff --git a/tensorflow/core/framework/run_handler_util_test.cc b/tensorflow/core/framework/run_handler_util_test.cc
new file mode 100644
index 0000000000..a1928c132b
--- /dev/null
+++ b/tensorflow/core/framework/run_handler_util_test.cc
@@ -0,0 +1,93 @@
+/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/framework/run_handler_util.h"
+
+#include <vector>
+#include "tensorflow/core/lib/strings/strcat.h"
+#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/platform/test.h"
+namespace tensorflow {
+namespace {
+
+void VerifyFunction(int num_active_requests, int num_threads,
+ int min_threads_per_request, bool print_stats = false) {
+ if (print_stats) {
+ LOG(INFO) << "Test case# num_active_requests: " << num_active_requests
+ << " num_threads: " << num_threads
+ << " min_threads: " << min_threads_per_request;
+ }
+ std::vector<std::uint_fast32_t> start(num_active_requests);
+ std::vector<std::uint_fast32_t> end(num_active_requests);
+
+ ComputeInterOpSchedulingRanges(num_active_requests, num_threads,
+ min_threads_per_request, &start, &end);
+ string range_str = "";
+ for (int i = 0; i < num_active_requests; ++i) {
+ if (i > 0) range_str += " ";
+ range_str += strings::StrCat("[", start[i], ", ", end[i], ")");
+
+ ASSERT_GE(start[i], 0) << range_str;
+ ASSERT_LE(end[i], num_threads) << range_str;
+ if (i > 0) {
+ // Due to linearly decreasing demand, #threads(i - 1) >= #threads(i)
+ ASSERT_GE(end[i - 1] - start[i - 1], end[i] - start[i]) << range_str;
+ // No missing threads.
+ ASSERT_GE(end[i - 1], start[i]) << range_str;
+ }
+ // Each interval is at least of size 'min_threads_per_request'.
+ ASSERT_GE((end[i] - start[i]), min_threads_per_request) << range_str;
+ // Verify that assigned (quantized) threads is not overly estimated
+ // from real demand, when the demand is high (>=
+ // min_threads_per_request).
+ float entry_weight = num_active_requests - i;
+ float total_weight = 0.5f * num_active_requests * (num_active_requests + 1);
+ float thread_demand = (entry_weight * num_threads) / total_weight;
+ if (thread_demand > min_threads_per_request) {
+ // We expect some over-estimation of threads due to quantization,
+ // but we hope it's not more than 1 extra thread.
+ ASSERT_NEAR(end[i] - start[i], thread_demand, 1.0)
+ << "Ranges: " << range_str << " thread_demand: " << thread_demand
+ << " i: " << i;
+ }
+ }
+ ASSERT_EQ(end[num_active_requests - 1], num_threads);
+ ASSERT_EQ(start[0], 0);
+ if (print_stats) {
+ LOG(INFO) << "Assigned ranges: " << range_str;
+ }
+}
+
+TEST(RunHandlerUtilTest, TestComputeInterOpSchedulingRanges) {
+ const int kMinThreadsPerRequestBound = 12;
+ const int kMaxActiveRequests = 128;
+ const int kMaxThreads = 128;
+
+ for (int min_threads_per_request = 1;
+ min_threads_per_request <= kMinThreadsPerRequestBound;
+ ++min_threads_per_request) {
+ for (int num_active_requests = 1; num_active_requests <= kMaxActiveRequests;
+ ++num_active_requests) {
+ for (int num_threads = min_threads_per_request;
+ num_threads <= kMaxThreads; ++num_threads) {
+ VerifyFunction(num_active_requests, num_threads,
+ min_threads_per_request);
+ }
+ }
+ }
+}
+
+} // namespace
+} // namespace tensorflow