diff options
author | Eugene Brevdo <ebrevdo@google.com> | 2018-10-10 08:36:36 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-10-10 08:40:03 -0700 |
commit | 79af30d357fbe0869e163e1d9dce0cb869b3724f (patch) | |
tree | aa4789c0aa0e10321afe4d3d84eae5fd0e84af3a /tensorflow/core/grappler/optimizers/graph_optimizer_stage.h | |
parent | 131f6f8429ffa0511a3d5a6a595843d3d96ec942 (diff) |
[Grappler] Add RemoveStackStridedSliceSameAxis optimizer.
// Replace operations of the form:
// x = stack((a_0, a_1, ..., a_{n-1}), axis=k)[:,...,i,...]
// with
// a_i
// when the strided slice index `i` is applied in the k'th axis.
//
// Similarly, replace operations of the form:
// x = stack((a_0, a_1, ..., a_{n-1}), axis=k)[:,...,i:i+1,...]
// with
// expand_dims(a_i, axis=k)
//
PiperOrigin-RevId: 216535346
Diffstat (limited to 'tensorflow/core/grappler/optimizers/graph_optimizer_stage.h')
-rw-r--r-- | tensorflow/core/grappler/optimizers/graph_optimizer_stage.h | 4 |
1 files changed, 4 insertions, 0 deletions
diff --git a/tensorflow/core/grappler/optimizers/graph_optimizer_stage.h b/tensorflow/core/grappler/optimizers/graph_optimizer_stage.h index 2afb5df431..f31a30ec0e 100644 --- a/tensorflow/core/grappler/optimizers/graph_optimizer_stage.h +++ b/tensorflow/core/grappler/optimizers/graph_optimizer_stage.h @@ -21,6 +21,7 @@ limitations under the License. #include "tensorflow/core/grappler/costs/graph_properties.h" #include "tensorflow/core/grappler/grappler_item.h" #include "tensorflow/core/grappler/utils.h" +#include "tensorflow/core/lib/gtl/flatset.h" #include "tensorflow/core/protobuf/rewriter_config.pb.h" namespace tensorflow { @@ -46,17 +47,20 @@ struct GraphOptimizerContext { GraphOptimizerContext(const std::unordered_set<string>* nodes_to_preserve, GraphDef* optimized_graph, GraphProperties* graph_properties, NodeMap* node_map, + gtl::FlatSet<string>* feed_nodes, RewriterConfig::Toggle opt_level) : nodes_to_preserve(nodes_to_preserve), optimized_graph(optimized_graph), graph_properties(graph_properties), node_map(node_map), + feed_nodes(feed_nodes), opt_level(opt_level) {} const std::unordered_set<string>* nodes_to_preserve; GraphDef* optimized_graph; GraphProperties* graph_properties; NodeMap* node_map; + gtl::FlatSet<string>* feed_nodes; RewriterConfig::Toggle opt_level; }; |