aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/grappler/optimizers/meta_optimizer.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/core/grappler/optimizers/meta_optimizer.cc')
-rw-r--r--tensorflow/core/grappler/optimizers/meta_optimizer.cc68
1 files changed, 66 insertions, 2 deletions
diff --git a/tensorflow/core/grappler/optimizers/meta_optimizer.cc b/tensorflow/core/grappler/optimizers/meta_optimizer.cc
index 3f33b16ba8..7488cedec5 100644
--- a/tensorflow/core/grappler/optimizers/meta_optimizer.cc
+++ b/tensorflow/core/grappler/optimizers/meta_optimizer.cc
@@ -14,6 +14,9 @@ limitations under the License.
==============================================================================*/
#include "tensorflow/core/grappler/optimizers/meta_optimizer.h"
+
+#include <memory>
+
#include "tensorflow/core/common_runtime/function.h"
#include "tensorflow/core/framework/function.pb.h"
#include "tensorflow/core/framework/versions.pb.h"
@@ -37,7 +40,11 @@ limitations under the License.
#include "tensorflow/core/grappler/utils/functions.h"
#include "tensorflow/core/grappler/utils/topological_sort.h"
#include "tensorflow/core/lib/core/status.h"
+#include "tensorflow/core/lib/core/threadpool.h"
#include "tensorflow/core/lib/gtl/map_util.h"
+#include "tensorflow/core/platform/cpu_info.h"
+#include "tensorflow/core/platform/notification.h"
+#include "tensorflow/core/platform/thread_annotations.h"
#include "tensorflow/core/util/ptr_util.h"
namespace tensorflow {
@@ -115,6 +122,21 @@ std::unique_ptr<GraphOptimizer> MetaOptimizer::MakeNewOptimizer(
#undef MK_OPT
+MetaOptimizer::MetaOptimizer(DeviceBase* cpu_device, const RewriterConfig& cfg)
+ : cpu_device_(cpu_device), cfg_(cfg) {
+ // TODO(rmlarsen): Increase kNumThreads to, say, port::NumSchedulableCPUs()
+ // if we want to the threadpool for parallelizing Grappler
+ const int kNumThreads = 1;
+ thread_pool_ = absl::make_unique<thread::ThreadPool>(
+ Env::Default(), "MetaOptimizerThreadPool", kNumThreads);
+}
+
+MetaOptimizer::~MetaOptimizer() {
+ // The ThreadPool destructor waits for threads to finish, so we don't
+ // pull the rug out from under them.
+ thread_pool_.reset();
+}
+
Status MetaOptimizer::InitializeOptimizers(
std::vector<std::unique_ptr<GraphOptimizer>>* optimizers) const {
if (cfg_.disable_meta_optimizer()) {
@@ -310,6 +332,7 @@ Status MetaOptimizer::OptimizeGraph(Cluster* cluster, const GrapplerItem& item,
VLOG(4) << "Starting optimization iteration " << iteration;
for (const auto& optimizer : optimizers) {
+ GRAPPLER_RETURN_IF_CANCELLED();
// Some optimizers can run only once.
if (iteration > 0 && IsRunOnceOptimizer(optimizer->name())) continue;
// Some must run only on the last iteration.
@@ -368,6 +391,7 @@ Status MetaOptimizer::RunOptimizer(
// resets optimized_graph to an empty graph.
optimized_graph->Swap(&optimized_item->graph);
*optimized_graph = GraphDef();
+ // TODO(rmlarsen): Add timeout for individual optimizers.
Status status =
optimizer->Optimize(cluster, *optimized_item, optimized_graph);
uint64 end_us = Env::Default()->NowMicros();
@@ -389,14 +413,15 @@ Status MetaOptimizer::RunOptimizer(
return status;
}
-Status MetaOptimizer::Optimize(Cluster* cluster, const GrapplerItem& item,
- GraphDef* optimized_graph) {
+Status MetaOptimizer::OptimizeMainGraphAndFunctionLibrary(
+ Cluster* cluster, const GrapplerItem& item, GraphDef* optimized_graph) {
VLOG(1) << "Starting optimization for grappler item: " << item.id;
optimization_results_.clear();
// 1. Optimize main graph
TF_RETURN_IF_ERROR(OptimizeGraph(cluster, item, optimized_graph));
VLOG(1) << "Optimized main graph.";
+ GRAPPLER_RETURN_IF_CANCELLED();
// Skip optimizing functions if this is a TPU graph. Currently, Grappler
// passes do not handle TPU functions correctly in a variety of ways (Note
@@ -432,6 +457,8 @@ Status MetaOptimizer::Optimize(Cluster* cluster, const GrapplerItem& item,
optimize_function_library = false;
for (const FunctionDef& func : optimized_graph->library().function()) {
+ GRAPPLER_RETURN_IF_CANCELLED();
+
const string& func_name = func.signature().name();
// Skip already optimized functions.
@@ -506,6 +533,43 @@ void MetaOptimizer::PrintResult() {
}
}
+Status MetaOptimizer::Optimize(Cluster* cluster, const GrapplerItem& item,
+ GraphDef* optimized_graph) {
+ const int64 kFiveMinutesInUsec = 5 * 60 * 1000 * 1000;
+ const int64 timeout_usec = (cfg_.meta_optimizer_timeout_ms() == 0
+ ? kFiveMinutesInUsec
+ : cfg_.meta_optimizer_timeout_ms() * 1000);
+ if (timeout_usec < 0) {
+ return OptimizeMainGraphAndFunctionLibrary(cluster, item, optimized_graph);
+ }
+
+ GraphDef optimized_with_timeout;
+ Status status;
+ Notification done;
+ thread_pool_->Schedule(
+ [this, cluster, &done, &optimized_with_timeout, &item, &status]() {
+ status = this->OptimizeMainGraphAndFunctionLibrary(
+ cluster, item, &optimized_with_timeout);
+ done.Notify();
+ });
+
+ const bool notified = WaitForNotificationWithTimeout(&done, timeout_usec);
+ if (notified && status.ok()) {
+ optimized_graph->Swap(&optimized_with_timeout);
+ } else {
+ *optimized_graph = item.graph;
+ if (!notified) {
+ this->Cancel();
+ done.WaitForNotification();
+ status = errors::DeadlineExceeded(
+ "Grappler MetaOptimizer timed out after ",
+ static_cast<float>(timeout_usec) / (1000 * 1000), " seconds");
+ LOG(WARNING) << status.error_message();
+ }
+ }
+ return status;
+}
+
void MetaOptimizer::Feedback(Cluster* cluster, const GrapplerItem& item,
const GraphDef& pruned_graph, double result) {
// Nothing to do for MetaOptimizer.