diff options
Diffstat (limited to 'tensorflow/core/grappler/optimizers/graph_optimizer.h')
-rw-r--r-- | tensorflow/core/grappler/optimizers/graph_optimizer.h | 21 |
1 files changed, 21 insertions, 0 deletions
diff --git a/tensorflow/core/grappler/optimizers/graph_optimizer.h b/tensorflow/core/grappler/optimizers/graph_optimizer.h index 765dd13263..bd6bf9f860 100644 --- a/tensorflow/core/grappler/optimizers/graph_optimizer.h +++ b/tensorflow/core/grappler/optimizers/graph_optimizer.h @@ -16,8 +16,11 @@ limitations under the License. #ifndef TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_GRAPH_OPTIMIZER_H_ #define TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_GRAPH_OPTIMIZER_H_ +#include <atomic> #include "tensorflow/core/framework/graph.pb.h" #include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/platform/mutex.h" +#include "tensorflow/core/platform/thread_annotations.h" namespace tensorflow { namespace grappler { @@ -29,6 +32,7 @@ struct GrapplerItem; // optimization of a GrapplerItem for running on a cluster. class GraphOptimizer { public: + GraphOptimizer() : is_cancelled_(false) {} virtual ~GraphOptimizer() {} virtual string name() const = 0; @@ -45,8 +49,25 @@ class GraphOptimizer { // call to Optimize) performed. Lower "result" scores are better. virtual void Feedback(Cluster* cluster, const GrapplerItem& item, const GraphDef& optimized_graph, double result) = 0; + + // Best effort cancellation. Sets is_cancelled to true and requests that the + // optimizer returns as soon as possible from active calls to Optimize() or + // FeedBack(). + void Cancel() { is_cancelled_ = true; } + + bool is_cancelled() const { return is_cancelled_; } + + private: + std::atomic<bool> is_cancelled_; }; +#define GRAPPLER_RETURN_IF_CANCELLED() \ + do { \ + if (is_cancelled()) { \ + return errors::DeadlineExceeded(this->name(), " was cancelled."); \ + } \ + } while (0) + } // end namespace grappler } // end namespace tensorflow |