aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/common_runtime/optimization_registry.h
diff options
context:
space:
mode:
authorGravatar Peter Hawkins <phawkins@google.com>2017-01-14 07:35:00 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-01-14 07:45:34 -0800
commit8eb161e39185d364e04a6edf2beadc2bb5cb978c (patch)
tree1256e6088fdf998e76a430c791e022f52b7bd16c /tensorflow/core/common_runtime/optimization_registry.h
parent7ab67d8b526d6f01c8c8347993d68f8f5074c184 (diff)
Add support for passes that run post-partitioning to OptimizationRegistry.
To avoid another GraphDef -> Graph -> GraphDef conversion, change Device::MaybeRewriteGraph to take a Graph instead of a GraphDef. Use std::unique_ptr<> in more places to avoid some awkward .release() magic. Change: 144532446
Diffstat (limited to 'tensorflow/core/common_runtime/optimization_registry.h')
-rw-r--r--tensorflow/core/common_runtime/optimization_registry.h12
1 files changed, 11 insertions, 1 deletions
diff --git a/tensorflow/core/common_runtime/optimization_registry.h b/tensorflow/core/common_runtime/optimization_registry.h
index 45c571f847..adfa17ae9d 100644
--- a/tensorflow/core/common_runtime/optimization_registry.h
+++ b/tensorflow/core/common_runtime/optimization_registry.h
@@ -39,8 +39,17 @@ struct GraphOptimizationPassOptions {
const CostModel* cost_model = nullptr;
FunctionLibraryDefinition* flib_def = nullptr; // Not owned.
+
+ // The graph to optimize, for optimization passes that run before
+ // partitioning. Null for post-partitioning passes.
// An optimization pass may replace *graph with a new graph object.
- std::unique_ptr<Graph>* graph;
+ std::unique_ptr<Graph>* graph = nullptr;
+
+ // Graphs for each partition, if running post-partitioning. Optimization
+ // passes may alter the graphs, but must not add or remove partitions.
+ // Null for pre-partitioning passes.
+ std::unordered_map<string, std::unique_ptr<Graph>>* partition_graphs =
+ nullptr;
};
// Optimization passes are implemented by inheriting from
@@ -64,6 +73,7 @@ class OptimizationPassRegistry {
PRE_PLACEMENT, // after cost model assignment, before placement.
POST_PLACEMENT, // after placement.
POST_REWRITE_FOR_EXEC, // after re-write using feed/fetch endpoints.
+ POST_PARTITIONING, // after partitioning
};
// Add an optimization pass to the registry.