aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
-rw-r--r-- tensorflow/core/grappler/BUILD 1
-rw-r--r-- tensorflow/core/grappler/clusters/BUILD 1
-rw-r--r-- tensorflow/core/grappler/clusters/single_machine.cc 194
-rw-r--r-- tensorflow/core/grappler/clusters/single_machine.h 11
-rw-r--r-- tensorflow/core/grappler/costs/cost_estimator.h 3
-rw-r--r-- tensorflow/core/grappler/utils.cc 24
-rw-r--r-- tensorflow/core/grappler/utils.h 13
-rw-r--r-- tensorflow/core/grappler/utils_test.cc 21
8 files changed, 158 insertions, 110 deletions
diff --git a/tensorflow/core/grappler/BUILD b/tensorflow/core/grappler/BUILD
index 53714367f5..6d77df6cb3 100644
--- a/tensorflow/core/grappler/BUILD
+++ b/tensorflow/core/grappler/BUILD
@@ -32,6 +32,7 @@ cc_test(
srcs = ["utils_test.cc"],
deps = [
":utils",
+ "//tensorflow/core:lib",
"//tensorflow/core:test",
"//tensorflow/core:test_main",
],
diff --git a/tensorflow/core/grappler/clusters/BUILD b/tensorflow/core/grappler/clusters/BUILD
index c420818333..1bfe69545b 100644
--- a/tensorflow/core/grappler/clusters/BUILD
+++ b/tensorflow/core/grappler/clusters/BUILD
@@ -44,6 +44,7 @@ cc_library(
"//tensorflow/core:core_cpu",
"//tensorflow/core:direct_session",
"//tensorflow/core:lib",
+ "//tensorflow/core/grappler:utils",
"//tensorflow/core/kernels:ops_util",
],
)
diff --git a/tensorflow/core/grappler/clusters/single_machine.cc b/tensorflow/core/grappler/clusters/single_machine.cc
index eb1ebf63da..6cf3ed337d 100644
--- a/tensorflow/core/grappler/clusters/single_machine.cc
+++ b/tensorflow/core/grappler/clusters/single_machine.cc
@@ -15,6 +15,7 @@ limitations under the License.
#include "tensorflow/core/grappler/clusters/single_machine.h"
#include "tensorflow/cc/training/queue_runner.h"
+#include "tensorflow/core/grappler/utils.h"
#include "tensorflow/core/kernels/ops_util.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/strings/strcat.h"
@@ -26,7 +27,6 @@ namespace grappler {
SingleMachine::SingleMachine(int timeout_s, int num_cpu_cores, int num_gpus)
: Cluster(timeout_s),
num_gpus_(num_gpus),
- running_(false),
closing_(false) {
thread_pool_.reset(new thread::ThreadPool(
Env::Default(), SanitizeThreadSuffix("single_machine"), 2));
@@ -38,13 +38,17 @@ SingleMachine::SingleMachine(int timeout_s, int num_cpu_cores, int num_gpus)
CHECK_GE(num_cpu_cores, 1);
options_.config.set_intra_op_parallelism_threads(num_cpu_cores);
options_.config.set_inter_op_parallelism_threads(num_cpu_cores);
+ if (timeout_s > 0) {
+ options_.config.set_operation_timeout_in_ms(timeout_s * 1000);
+ }
}
SingleMachine::~SingleMachine() {
CloseSession(false /*use_timeout*/).IgnoreError();
- // Prevent the destructor from deleting mu_ until CloseSession() is done.
- mutex_lock l(mu_);
+ // Reset the thread-pool so that there are no outstanding Session::Run(...)s
+ // when we delete the session.
+ thread_pool_.reset();
}
Status SingleMachine::Provision() {
@@ -68,6 +72,7 @@ Status SingleMachine::Provision() {
}
Status SingleMachine::Initialize(const GrapplerItem& item) {
+ mutex_lock l(this->last_graph_mu_);
if (last_graph_ != &item.graph || last_graph_id_ != item.id) {
init_ops_ = item.init_ops;
last_graph_ = nullptr;
@@ -81,36 +86,42 @@ Status SingleMachine::Run(const GraphDef& graph_def,
const std::vector<std::pair<string, Tensor>>& feed,
const std::vector<string>& fetch,
RunMetadata* metadata) {
- if (last_graph_ != &graph_def) {
- Status status = ResetSession();
- if (status.ok()) {
- status = session_->Create(graph_def);
- }
- if (!init_ops_.empty() && status.ok()) {
- status = RunWithTimeout({}, init_ops_, nullptr);
- }
- for (int i = 0; i < queue_runner_defs_.size() && status.ok(); ++i) {
- std::unique_ptr<QueueRunner> queue_runner;
- TF_RETURN_IF_ERROR(QueueRunner::New(queue_runner_defs_[i],
- coordinator_.get(), &queue_runner));
- TF_RETURN_IF_ERROR(queue_runner->Start(session_.get()));
- TF_RETURN_IF_ERROR(coordinator_->RegisterRunner(std::move(queue_runner)));
- status = coordinator_->GetStatus();
- }
-
- if (status.ok()) {
- last_graph_ = &graph_def;
- } else {
- return status;
- }
+ // Interface idea: What about having Initialize(item, graph_def), which
+ // initializes the graph, and then Run(feed, fetch, metadata).
+ {
+ mutex_lock l(this->last_graph_mu_);
+ if (last_graph_ != &graph_def) {
+ Status status = ResetSession();
+ if (status.ok()) {
+ status = session_->Create(graph_def);
+ }
+ if (!init_ops_.empty() && status.ok()) {
+ status = RunWithTimeout({}, init_ops_, nullptr);
+ }
+ for (int i = 0; i < queue_runner_defs_.size() && status.ok(); ++i) {
+ std::unique_ptr<QueueRunner> queue_runner;
+ TF_RETURN_IF_ERROR(QueueRunner::New(queue_runner_defs_[i],
+ coordinator_.get(), &queue_runner));
+ TF_RETURN_IF_ERROR(queue_runner->Start(session_.get()));
+ TF_RETURN_IF_ERROR(
+ coordinator_->RegisterRunner(std::move(queue_runner)));
+ status = coordinator_->GetStatus();
+ }
- // Warmup TensorFlow if needed
- for (int i = 0;
- i < options_.config.graph_options().build_cost_model_after(); ++i) {
- status = RunWithTimeout(feed, fetch, nullptr);
- if (!status.ok()) {
+ if (status.ok()) {
+ last_graph_ = &graph_def;
+ } else {
return status;
}
+
+ // Warmup TensorFlow if needed
+ for (int i = 0;
+ i < options_.config.graph_options().build_cost_model_after(); ++i) {
+ status = RunWithTimeout(feed, fetch, nullptr);
+ if (!status.ok()) {
+ return status;
+ }
+ }
}
}
@@ -125,37 +136,31 @@ Status SingleMachine::AllowSoftPlacement(bool soft_placement_state) {
Status SingleMachine::RunWithTimeout(
const std::vector<std::pair<string, Tensor>>& feed,
const std::vector<string>& fetch, RunMetadata* run_metadata) {
- mutex_lock l(mu_);
// We shouldn't be running or closing the session at this point.
- CHECK(!running_);
- CHECK(!closing_);
-
- running_ = true;
- metadata_ = RunMetadata();
-
- thread_pool_->Schedule([this, feed, fetch] {
- Status status =
- session_->Run(run_options_, feed, {}, fetch, nullptr, &this->metadata_);
- mutex_lock l(mu_);
- status_ = status;
- running_ = false;
- done_running_.notify_all();
- });
-
- while (running_) {
- std::cv_status timeout =
- done_running_.wait_for(l, std::chrono::milliseconds(timeout_s_ * 1000));
- if (timeout != std::cv_status::no_timeout) {
- last_graph_ = nullptr;
- return Status(error::DEADLINE_EXCEEDED,
- strings::StrCat("Failed to run the graph after ",
- timeout_s_, " seconds, aborting"));
- }
+ {
+ mutex_lock l(close_mu_);
+ CHECK(!closing_);
}
- if (run_metadata && status_.ok()) {
- *run_metadata = metadata_;
+ auto status = std::make_shared<Status>();
+ const bool executed_in_time = ExecuteWithTimeout(
+ [status, this, &run_metadata, &feed, &fetch]() {
+ if (!run_metadata) {
+ RunMetadata unused;
+ *status =
+ session_->Run(run_options_, feed, {}, fetch, nullptr, &unused);
+ } else {
+ *status = session_->Run(run_options_, feed, {}, fetch, nullptr,
+ run_metadata);
+ }
+ },
+ timeout_s_ * 1000, thread_pool_.get());
+ if (!executed_in_time) {
+ mutex_lock l(last_graph_mu_);
+ last_graph_ = nullptr;
+ return errors::DeadlineExceeded("Failed to run the graph after ",
+ timeout_s_, " seconds, aborting");
}
- return status_;
+ return *status;
}
Status SingleMachine::CloseSession(bool use_timeout) {
@@ -163,54 +168,41 @@ Status SingleMachine::CloseSession(bool use_timeout) {
return Status::OK();
}
- mutex_lock l(close_mu_);
+ {
+ mutex_lock l(close_mu_);
- if (!closing_) {
- closing_ = true;
+ if (!closing_) {
+ closing_ = true;
+ }
+ }
- thread_pool_->Schedule([this] {
- if (this->coordinator_) {
- this->coordinator_->RequestStop().IgnoreError();
- // Wait for all the runners to have closed their queues.
- while (!this->coordinator_->AllRunnersStopped()) {
- sleep(1);
+ const bool executed_in_time = ExecuteWithTimeout(
+ [&]() {
+ if (this->coordinator_) {
+ this->coordinator_->RequestStop().IgnoreError();
+ // Wait for all the runners to have closed their queues.
+ while (!this->coordinator_->AllRunnersStopped()) {
+ sleep(1);
+ }
+ // Now we can close the session. This should cancel any pending I/O
+ // operation.
+ this->session_->Close().IgnoreError();
+ // Last but not least, we can delete the coordinator.
+ this->coordinator_.reset();
+ } else {
+ this->session_->Close().IgnoreError();
}
- // Now we can close the session. This should cancel any pending I/O
- // operation.
- this->session_->Close().IgnoreError();
- // Last but not least, we can delete the coordinator.
- this->coordinator_.reset();
- } else {
- this->session_->Close().IgnoreError();
- }
-
- // Wait for any previous run to finish.
- mutex_lock l(mu_);
- while (running_) {
- done_running_.wait(l);
- }
- mutex_lock l2(close_mu_);
- closing_ = false;
- done_closing_.notify_all();
- });
- }
+ mutex_lock l2(close_mu_);
+ closing_ = false;
+ },
+ use_timeout ? timeout_s_ * 1000 : -1, thread_pool_.get());
- while (closing_) {
- if (!use_timeout) {
- done_closing_.wait(l);
- } else {
- std::cv_status timeout = done_closing_.wait_for(
- l, std::chrono::milliseconds(timeout_s_ * 1000));
- if (timeout != std::cv_status::no_timeout) {
- // Let the caller know that we can't shutdown the session, and therefore
- // can't process any further.
- return Status(
- error::UNAVAILABLE,
- strings::StrCat("Failed to close the previous session after ",
- timeout_s_, " seconds, aborting"));
- }
- }
+ if (!executed_in_time) {
+ // Let the caller know that we can't shutdown the session, and therefore
+ // can't process any further.
+ return errors::Unavailable("Failed to close the previous session after ",
+ timeout_s_, " seconds, aborting");
}
return Status::OK();
diff --git a/tensorflow/core/grappler/clusters/single_machine.h b/tensorflow/core/grappler/clusters/single_machine.h
index 5f3dd6f8ce..168110824d 100644
--- a/tensorflow/core/grappler/clusters/single_machine.h
+++ b/tensorflow/core/grappler/clusters/single_machine.h
@@ -48,22 +48,17 @@ class SingleMachine : public Cluster {
const int num_gpus_;
std::unique_ptr<Session> session_;
std::vector<QueueRunnerDef> queue_runner_defs_;
- const GraphDef* last_graph_ = nullptr; // Not owned.
string last_graph_id_;
+ mutex last_graph_mu_;
+ const GraphDef* last_graph_ GUARDED_BY(last_graph_mu_) = nullptr;
std::vector<string> init_ops_;
std::unique_ptr<Coordinator> coordinator_;
std::unique_ptr<thread::ThreadPool> thread_pool_;
Status status_;
- RunMetadata metadata_;
-
- mutex mu_;
- bool running_;
- condition_variable done_running_;
mutex close_mu_;
- bool closing_;
- condition_variable done_closing_;
+ bool closing_ GUARDED_BY(close_mu_);
};
} // end namespace grappler
diff --git a/tensorflow/core/grappler/costs/cost_estimator.h b/tensorflow/core/grappler/costs/cost_estimator.h
index 3c65c34f8d..093b7e29dc 100644
--- a/tensorflow/core/grappler/costs/cost_estimator.h
+++ b/tensorflow/core/grappler/costs/cost_estimator.h
@@ -68,7 +68,10 @@ struct Costs {
typedef NanoSeconds Duration;
// Overall cost of running the graph; latency.
+ // Mean execution time; the minimum and maximum are stored separately below.
Duration execution_time;
+ Duration min_execution_time;
+ Duration max_execution_time;
// Computation cost of running the graph.
Duration compute_time;
diff --git a/tensorflow/core/grappler/utils.cc b/tensorflow/core/grappler/utils.cc
index 35e62432bb..03ae6b8278 100644
--- a/tensorflow/core/grappler/utils.cc
+++ b/tensorflow/core/grappler/utils.cc
@@ -13,11 +13,14 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#include "tensorflow/core/grappler/utils.h"
+#include <memory>
+
#include "tensorflow/core/common_runtime/gpu/gpu_init.h"
+#include "tensorflow/core/grappler/utils.h"
#include "tensorflow/core/lib/strings/scanner.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/cpu_info.h"
+#include "tensorflow/core/platform/notification.h"
#include "tensorflow/core/platform/stream_executor.h"
namespace tensorflow {
@@ -96,5 +99,24 @@ string AddPrefixToNodeName(const string& name, const string& prefix) {
return strings::StrCat(prefix, "-", name);
}
+// Runs 'fn' and waits up to 'timeout_in_ms' milliseconds for it to finish.
+//
+// A non-positive 'timeout_in_ms' runs 'fn' synchronously on the calling
+// thread (the thread pool is not used) and returns true. Otherwise 'fn' is
+// scheduled on 'thread_pool'; the return value is true if it completed within
+// the timeout and false if the wait timed out, in which case 'fn' may still
+// be pending or running in 'thread_pool'.
+bool ExecuteWithTimeout(std::function<void()> fn, const int64 timeout_in_ms,
+                        thread::ThreadPool* const thread_pool) {
+  if (timeout_in_ms <= 0) {
+    fn();
+    return true;
+  }
+  auto done = std::make_shared<Notification>();
+  // Capture 'fn' by value: on timeout this function returns (destroying its
+  // 'fn' parameter) while the scheduled task may not have run yet, so a
+  // by-reference capture would dangle. 'done' is shared for the same reason.
+  thread_pool->Schedule([done, fn]() {
+    fn();
+    done->Notify();
+  });
+  // WaitForNotificationWithTimeout takes its timeout in microseconds.
+  return WaitForNotificationWithTimeout(done.get(), timeout_in_ms * 1000);
+}
+
} // end namespace grappler
} // end namespace tensorflow
diff --git a/tensorflow/core/grappler/utils.h b/tensorflow/core/grappler/utils.h
index 98d75fd327..2abd200cd8 100644
--- a/tensorflow/core/grappler/utils.h
+++ b/tensorflow/core/grappler/utils.h
@@ -16,6 +16,10 @@ limitations under the License.
#ifndef TENSORFLOW_GRAPPLER_UTILS_H_
#define TENSORFLOW_GRAPPLER_UTILS_H_
+#include <functional>
+
+#include "tensorflow/core/lib/core/status.h"
+#include "tensorflow/core/lib/core/threadpool.h"
#include "tensorflow/core/platform/types.h"
namespace tensorflow {
@@ -38,6 +42,15 @@ int NodePosition(const string& name);
// Add a prefix to a node name
string AddPrefixToNodeName(const string& name, const string& prefix);
+// Executes 'fn' in 'thread_pool' and waits up to 'timeout_in_ms' milliseconds
+// for it to complete. Returns true if 'fn' finished within the timeout and
+// false otherwise. If 'timeout_in_ms' is not positive, 'fn' is instead run
+// directly on the calling thread with no time limit, and true is returned.
+//
+// If returning false, 'fn' may still continue to execute in the thread-pool.
+// It is the responsibility of the caller to reset the thread-pool as
+// appropriate.
+bool ExecuteWithTimeout(std::function<void()> fn, int64 timeout_in_ms,
+ thread::ThreadPool* thread_pool);
+
} // end namespace grappler
} // end namespace tensorflow
diff --git a/tensorflow/core/grappler/utils_test.cc b/tensorflow/core/grappler/utils_test.cc
index f5d6877637..766cedd65d 100644
--- a/tensorflow/core/grappler/utils_test.cc
+++ b/tensorflow/core/grappler/utils_test.cc
@@ -14,6 +14,9 @@ limitations under the License.
==============================================================================*/
#include "tensorflow/core/grappler/utils.h"
+#include "tensorflow/core/lib/core/status.h"
+#include "tensorflow/core/lib/core/threadpool.h"
+#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/test.h"
namespace tensorflow {
@@ -60,6 +63,24 @@ TEST_F(UtilsTest, AddNodeNamePrefix) {
EXPECT_EQ("OPTIMIZED-", AddPrefixToNodeName("", "OPTIMIZED"));
}
+// Exercises ExecuteWithTimeout on its three paths: completes in time, times
+// out, and runs without a timeout.
+TEST_F(UtilsTest, ExecuteWithTimeout) {
+  std::unique_ptr<thread::ThreadPool> thread_pool(
+      new thread::ThreadPool(Env::Default(), "ExecuteWithTimeout", 2));
+  // This should run till the end. The timeout is deliberately generous: the
+  // closure is a no-op, but handing it to the pool and being notified can
+  // take longer than 1 ms on a loaded machine, which would make this flaky.
+  ASSERT_TRUE(ExecuteWithTimeout(
+      []() {  // Do nothing.
+      },
+      1000 /* timeout_in_ms */, thread_pool.get()));
+  // This should time out.
+  ASSERT_FALSE(ExecuteWithTimeout([]() { sleep(1); }, 1 /* timeout_in_ms */,
+                                  thread_pool.get()));
+  // This should run till the end: a non-positive timeout runs the closure
+  // on the calling thread with no time limit.
+  ASSERT_TRUE(ExecuteWithTimeout([]() { sleep(1); }, 0 /* timeout_in_ms */,
+                                 thread_pool.get()));
+
+  // Deleting before local variables go off the stack.
+  thread_pool.reset();
+}
+
} // namespace
} // namespace grappler
} // namespace tensorflow