aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/grappler/optimizers/graph_optimizer_stage.h
diff options
context:
space:
mode:
authorGravatar Eugene Brevdo <ebrevdo@google.com>2018-10-10 08:36:36 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-10-10 08:40:03 -0700
commit79af30d357fbe0869e163e1d9dce0cb869b3724f (patch)
treeaa4789c0aa0e10321afe4d3d84eae5fd0e84af3a /tensorflow/core/grappler/optimizers/graph_optimizer_stage.h
parent131f6f8429ffa0511a3d5a6a595843d3d96ec942 (diff)
[Grappler] Add RemoveStackStridedSliceSameAxis optimizer.
// Replace operations of the form: // x = stack((a_0, a_1, ..., a_{n-1}), axis=k)[:,...,i,...] // with // a_i // when the strided slice index `i` is applied in the k'th axis. // // Similarly, replace operations of the form: // x = stack((a_0, a_1, ..., a_{n-1}), axis=k)[:,...,i:i+1,...] // with // expand_dims(a_i, axis=k) // PiperOrigin-RevId: 216535346
Diffstat (limited to 'tensorflow/core/grappler/optimizers/graph_optimizer_stage.h')
-rw-r--r--tensorflow/core/grappler/optimizers/graph_optimizer_stage.h4
1 files changed, 4 insertions, 0 deletions
diff --git a/tensorflow/core/grappler/optimizers/graph_optimizer_stage.h b/tensorflow/core/grappler/optimizers/graph_optimizer_stage.h
index 2afb5df431..f31a30ec0e 100644
--- a/tensorflow/core/grappler/optimizers/graph_optimizer_stage.h
+++ b/tensorflow/core/grappler/optimizers/graph_optimizer_stage.h
@@ -21,6 +21,7 @@ limitations under the License.
#include "tensorflow/core/grappler/costs/graph_properties.h"
#include "tensorflow/core/grappler/grappler_item.h"
#include "tensorflow/core/grappler/utils.h"
+#include "tensorflow/core/lib/gtl/flatset.h"
#include "tensorflow/core/protobuf/rewriter_config.pb.h"
namespace tensorflow {
@@ -46,17 +47,20 @@ struct GraphOptimizerContext {
GraphOptimizerContext(const std::unordered_set<string>* nodes_to_preserve,
GraphDef* optimized_graph,
GraphProperties* graph_properties, NodeMap* node_map,
+ gtl::FlatSet<string>* feed_nodes,
RewriterConfig::Toggle opt_level)
: nodes_to_preserve(nodes_to_preserve),
optimized_graph(optimized_graph),
graph_properties(graph_properties),
node_map(node_map),
+ feed_nodes(feed_nodes),
opt_level(opt_level) {}
const std::unordered_set<string>* nodes_to_preserve;
GraphDef* optimized_graph;
GraphProperties* graph_properties;
NodeMap* node_map;
+ gtl::FlatSet<string>* feed_nodes;
RewriterConfig::Toggle opt_level;
};