aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/grappler/optimizers/graph_optimizer.h
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/core/grappler/optimizers/graph_optimizer.h')
-rw-r--r--tensorflow/core/grappler/optimizers/graph_optimizer.h21
1 files changed, 21 insertions, 0 deletions
diff --git a/tensorflow/core/grappler/optimizers/graph_optimizer.h b/tensorflow/core/grappler/optimizers/graph_optimizer.h
index 765dd13263..bd6bf9f860 100644
--- a/tensorflow/core/grappler/optimizers/graph_optimizer.h
+++ b/tensorflow/core/grappler/optimizers/graph_optimizer.h
@@ -16,8 +16,11 @@ limitations under the License.
#ifndef TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_GRAPH_OPTIMIZER_H_
#define TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_GRAPH_OPTIMIZER_H_
+#include <atomic>
#include "tensorflow/core/framework/graph.pb.h"
#include "tensorflow/core/lib/core/status.h"
+#include "tensorflow/core/platform/mutex.h"
+#include "tensorflow/core/platform/thread_annotations.h"
namespace tensorflow {
namespace grappler {
@@ -29,6 +32,7 @@ struct GrapplerItem;
// optimization of a GrapplerItem for running on a cluster.
class GraphOptimizer {
public:
+ GraphOptimizer() : is_cancelled_(false) {}
virtual ~GraphOptimizer() {}
virtual string name() const = 0;
@@ -45,8 +49,25 @@ class GraphOptimizer {
// call to Optimize) performed. Lower "result" scores are better.
virtual void Feedback(Cluster* cluster, const GrapplerItem& item,
const GraphDef& optimized_graph, double result) = 0;
+
+ // Best effort cancellation. Sets is_cancelled to true and requests that the
+ // optimizer returns as soon as possible from active calls to Optimize() or
+ // FeedBack().
+ void Cancel() { is_cancelled_ = true; }
+
+ bool is_cancelled() const { return is_cancelled_; }
+
+ private:
+ std::atomic<bool> is_cancelled_;
};
+#define GRAPPLER_RETURN_IF_CANCELLED() \
+ do { \
+ if (is_cancelled()) { \
+ return errors::DeadlineExceeded(this->name(), " was cancelled."); \
+ } \
+ } while (0)
+
} // end namespace grappler
} // end namespace tensorflow