author    Eugene Brevdo <ebrevdo@google.com>              2018-10-10 08:36:36 -0700
committer TensorFlower Gardener <gardener@tensorflow.org> 2018-10-10 08:40:03 -0700
commit    79af30d357fbe0869e163e1d9dce0cb869b3724f
tree      aa4789c0aa0e10321afe4d3d84eae5fd0e84af3a
parent    131f6f8429ffa0511a3d5a6a595843d3d96ec942
[Grappler] Add RemoveStackStridedSliceSameAxis optimizer.
Replace operations of the form:

    x = stack((a_0, a_1, ..., a_{n-1}), axis=k)[:,...,i,...]

with

    a_i

when the strided slice index `i` is applied in the k'th axis.

Similarly, replace operations of the form:

    x = stack((a_0, a_1, ..., a_{n-1}), axis=k)[:,...,i:i+1,...]

with

    expand_dims(a_i, axis=k)

PiperOrigin-RevId: 216535346
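As a concrete illustration of the two patterns (not part of the commit; the tensor names are arbitrary and the snippet assumes the TensorFlow Python API):

    import tensorflow as tf

    a = tf.ones([2, 2])
    b = tf.zeros([2, 2])
    c = tf.fill([2, 2], 2.0)

    stacked = tf.stack([a, b, c], axis=1)  # shape [2, 3, 2]

    # Candidate for the first rewrite: reduces to just `b` (an Identity of it).
    x = stacked[:, 1, :]
    # Candidate for the second rewrite: reduces to expand_dims(b, axis=1).
    y = stacked[:, 1:2, :]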
-rw-r--r--  tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc        295
-rw-r--r--  tensorflow/core/grappler/optimizers/arithmetic_optimizer.h           3
-rw-r--r--  tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc    211
-rw-r--r--  tensorflow/core/grappler/optimizers/graph_optimizer_stage.h           4
-rw-r--r--  tensorflow/core/grappler/optimizers/graph_optimizer_stage_test.cc     3

5 files changed, 515 insertions(+), 1 deletion(-)
diff --git a/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc b/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc
index 7d5014ee0a..0c2686a419 100644
--- a/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc
+++ b/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc
@@ -46,6 +46,7 @@ limitations under the License.
#include "tensorflow/core/platform/tensor_coding.h"
#include "tensorflow/core/util/device_name_utils.h"
#include "tensorflow/core/util/saved_tensor_slice_util.h"
+#include "tensorflow/core/util/strided_slice_op.h"
using tensorflow::strings::StrCat;
@@ -157,6 +158,14 @@ void SetSourceDataType(DataType dtype, NodeDef* node) {
SetDataTypeToAttr(dtype, SourceDataTypeAttrName(*node), node);
}
+Status CheckAttrExists(const NodeDef& node, const string& key) {
+ if (node.attr().count(key) == 0) {
+ return errors::InvalidArgument("Node '", node.name(), "' lacks '", key,
+ "' attr: ", node.DebugString());
+ }
+ return Status::OK();
+}
+
NodeDef* GetTailOfValuePreservingChain(
const NodeDef& node, const NodeMap& node_map,
const std::unordered_set<string>& nodes_to_preserve) {
@@ -2902,6 +2911,284 @@ class UnaryOpsComposition : public ArithmeticOptimizerStage {
std::unordered_set<string> fused_nodes_;
};
+// Replace operations of the form:
+// x = stack((a_0, a_1, ..., a_{n-1}), axis=k)[:,...,i,...]
+// with
+// a_i
+// when the strided slice index `i` is applied in the k'th axis.
+//
+// Similarly, replace operations of the form:
+// x = stack((a_0, a_1, ..., a_{n-1}), axis=k)[:,...,i:i+1,...]
+// with
+// expand_dims(a_i, axis=k)
+//
+// TODO(ebrevdo): Extend to also replace operations of the form
+// concat((a_0, a_1, ..., ), axis=k)[:, ..., s_i:s_{i+1}, ...]
+// with
+// a_i,
+// when
+// s_i = cumsum(shape(a)[k] for a in (a_0, ...,))[i]
+// and slicing is in the k'th axis.
+class RemoveStackStridedSliceSameAxis : public ArithmeticOptimizerStage {
+ public:
+ explicit RemoveStackStridedSliceSameAxis(
+ const GraphOptimizerContext& ctx,
+ const ArithmeticOptimizerContext& ctx_ext)
+ : ArithmeticOptimizerStage("RemoveStackStridedSliceSameAxis", ctx,
+ ctx_ext) {}
+ ~RemoveStackStridedSliceSameAxis() override = default;
+
+ bool IsSupported(const NodeDef* node) const override {
+ return IsStridedSlice(*node);
+ }
+
+ Status TrySimplify(NodeDef* node, string* simplified_node_name) override {
+ // *node is a StridedSlice NodeDef.
+ NodeDef* pack;
+
+ // Get the input and see if it's a Pack op.
+ TF_RETURN_IF_ERROR(GetInputNode(node->input(0), &pack));
+ if (!IsPack(*pack)) return Status::OK();
+
+ bool return_early;
+ PartialTensorShape pack_output_shape;
+ int pack_axis;
+ TF_RETURN_IF_ERROR(
+ CheckInputs(node, pack, &pack_output_shape, &pack_axis, &return_early));
+ if (return_early) return Status::OK();
+
+ int slice_start_value;
+ bool found;
+ TF_RETURN_IF_ERROR(GetSliceAxis(node, pack, pack_output_shape, pack_axis,
+ &slice_start_value, &found));
+ if (!found) return Status::OK();
+
+ return RewriteGraph(node, pack, slice_start_value, pack_axis,
+ simplified_node_name);
+ }
+
+ protected:
+ bool IsReallyConstant(const NodeDef& node) const {
+ if (!IsConstant(node)) {
+ return false;
+ }
+ // If the node is fed, it's not constant anymore.
+ return ctx().feed_nodes->find(node.name()) == ctx().feed_nodes->end();
+ }
+
+ bool GetConstantAsInt64(const NodeDef& node, DataType dtype,
+ std::vector<int64>* values) {
+ if (dtype == DT_INT32) {
+ std::vector<int32> values_int32;
+ if (!ValuesFromConstNode(node, &values_int32)) {
+ return false;
+ }
+ std::copy(values_int32.begin(), values_int32.end(),
+ std::inserter(*values, values->begin()));
+ return true;
+ } else {
+ return ValuesFromConstNode(node, values);
+ }
+ }
+
+ Status CheckInputs(const NodeDef* node, const NodeDef* pack,
+ PartialTensorShape* pack_output_shape, int* pack_axis,
+ bool* return_early) {
+ *return_early = true;
+ TF_RETURN_IF_ERROR(CheckAttrExists(*pack, "axis"));
+
+ *pack_axis = pack->attr().at("axis").i();
+ auto slice_properties =
+ ctx().graph_properties->GetInputProperties(node->name());
+ if (slice_properties.empty()) {
+ // Input shapes were not inferred; bail out.
+ return Status::OK();
+ }
+ *pack_output_shape = slice_properties[0].shape();
+ if (pack_output_shape->unknown_rank()) {
+ return Status::OK();
+ }
+ const int pack_output_rank = pack_output_shape->dims();
+ if (*pack_axis < 0) {
+ // The rank of any Pack input is one less than the Pack output rank, and
+ // negative axis values are counted from the output rank.
+ *pack_axis += pack_output_rank;
+ }
+ if (*pack_axis < 0 || *pack_axis >= pack_output_rank) {
+ return errors::InvalidArgument(
+ "Pack node (", pack->name(),
+ ") axis attribute is out of bounds: ", pack->attr().at("axis").i());
+ }
+ *return_early = false;
+ return Status::OK();
+ }
+
+ Status GetSliceAxis(const NodeDef* node, const NodeDef* pack,
+ const PartialTensorShape& pack_output_shape,
+ int pack_axis, int* slice_start_value, bool* found) {
+ *found = false;
+ for (auto key : {"begin_mask", "end_mask", "ellipsis_mask", "new_axis_mask",
+ "shrink_axis_mask"}) {
+ TF_RETURN_IF_ERROR(CheckAttrExists(*node, key));
+ }
+
+ const int begin_mask = node->attr().at("begin_mask").i();
+ const int end_mask = node->attr().at("end_mask").i();
+ const int ellipsis_mask = node->attr().at("ellipsis_mask").i();
+ const int new_axis_mask = node->attr().at("new_axis_mask").i();
+ const int shrink_axis_mask = node->attr().at("shrink_axis_mask").i();
+
+ // Check that the StridedSlice is one of these at pack_axis:
+ // [..., i, ...]
+ // [..., i:i+1, ...]
+ // [..., :1, ...]
+ // [..., -1:, ...]
+ // [..., s_{pack_axis}-1:, ...]
+ NodeDef* slice_begin;
+ NodeDef* slice_end;
+ NodeDef* slice_strides;
+ TF_RETURN_IF_ERROR(GetInputNode(node->input(1), &slice_begin));
+ TF_RETURN_IF_ERROR(GetInputNode(node->input(2), &slice_end));
+ TF_RETURN_IF_ERROR(GetInputNode(node->input(3), &slice_strides));
+
+ for (const auto* n : {slice_begin, slice_end, slice_strides}) {
+ if (!IsReallyConstant(*n)) return Status::OK();
+ }
+
+ Tensor slice_begin_t;
+ Tensor slice_end_t;
+ Tensor slice_strides_t;
+
+ TF_RETURN_IF_ERROR(CheckAttrExists(*slice_begin, "value"));
+ TF_RETURN_IF_ERROR(CheckAttrExists(*slice_end, "value"));
+ TF_RETURN_IF_ERROR(CheckAttrExists(*slice_strides, "value"));
+
+ if (!slice_begin_t.FromProto(slice_begin->attr().at("value").tensor())) {
+ return Status::OK();
+ }
+ if (!slice_end_t.FromProto(slice_end->attr().at("value").tensor())) {
+ return Status::OK();
+ }
+ if (!slice_strides_t.FromProto(
+ slice_strides->attr().at("value").tensor())) {
+ return Status::OK();
+ }
+ TensorShape processing_shape;
+ TensorShape final_shape;
+ bool is_identity;
+ bool is_simple_slice;
+ bool slice_dim0;
+ gtl::InlinedVector<int64, 4> slice_begin_vec;
+ gtl::InlinedVector<int64, 4> slice_end_vec;
+ gtl::InlinedVector<int64, 4> slice_strides_vec;
+ TF_RETURN_IF_ERROR(ValidateStridedSliceOp(
+ &slice_begin_t, &slice_end_t, slice_strides_t, pack_output_shape,
+ begin_mask, end_mask, ellipsis_mask, new_axis_mask, shrink_axis_mask,
+ &processing_shape, &final_shape, &is_identity, &is_simple_slice,
+ &slice_dim0, &slice_begin_vec, &slice_end_vec, &slice_strides_vec));
+
+ if (!is_simple_slice) return Status::OK();
+
+ int begin_index = -1;
+ int64 begin_value = 0;
+ for (int i = 0; i < slice_begin_vec.size(); ++i) {
+ const int64 v = slice_begin_vec[i];
+ if (v != 0) {
+ if (begin_index != -1) {
+ // At least two start values that are nonzero.
+ return Status::OK();
+ }
+ begin_index = i;
+ begin_value = v;
+ }
+ }
+
+ int end_index = -1;
+ int64 end_value = 0;
+ for (int i = 0; i < slice_end_vec.size(); ++i) {
+ const int64 v = slice_end_vec[i];
+ if (v != pack_output_shape.dim_size(i)) {
+ if (end_index != -1) {
+ // At least two end values that differ from the dimension size.
+ return Status::OK();
+ }
+ end_index = i;
+ end_value = v;
+ }
+ }
+
+ if (begin_index == -1 && end_index == -1) return Status::OK();
+ if (begin_index != -1 && end_index != -1 && begin_index != end_index) {
+ // Somehow received different axes for begin/end slicing.
+ return Status::OK();
+ }
+ const int slice_axis = (begin_index == -1) ? end_index : begin_index;
+ if (slice_axis != pack_axis) {
+ // Not slicing on the same axis as the Pack op.
+ return Status::OK();
+ }
+ *slice_start_value = (begin_index == -1) ? 0 : begin_value;
+ const int64 slice_end_value =
+ (end_index == -1) ? pack_output_shape.dim_size(slice_axis) : end_value;
+ if (slice_end_value != *slice_start_value + 1) {
+ // Not slicing a single value out.
+ return Status::OK();
+ }
+
+ if (*slice_start_value < 0 || *slice_start_value >= pack->input_size()) {
+ return errors::InvalidArgument(
+ "Node ", node->name(), " requested invalid slice index ",
+ *slice_start_value, " on axis ", slice_axis,
+ " from tensor of shape: ", pack_output_shape.DebugString());
+ }
+
+ *found = true; // slice_start_value is valid.
+ return Status::OK();
+ }
+
+ Status RewriteGraph(const NodeDef* node, const NodeDef* pack,
+ int slice_start_value, int pack_axis,
+ string* simplified_node_name) {
+ OpInfo::TensorProperties input_slice_properties;
+ NodeDef* input_slice;
+ TF_RETURN_IF_ERROR(
+ GetInputNode(pack->input(slice_start_value), &input_slice));
+ TF_RETURN_IF_ERROR(GetTensorProperties(pack->input(slice_start_value),
+ &input_slice_properties));
+ PartialTensorShape input_slice_shape(input_slice_properties.shape());
+
+ OpInfo::TensorProperties output_properties;
+ TF_RETURN_IF_ERROR(GetTensorProperties(
+ strings::StrCat(node->name(), ":", 0), &output_properties));
+ PartialTensorShape output_shape(output_properties.shape());
+ NodeDef* output =
+ AddEmptyNode(OptimizedNodeName(ParseNodeScopeAndName(node->name())));
+ if (input_slice_shape.IsCompatibleWith(output_shape)) {
+ output->set_op("Identity");
+ output->set_device(node->device());
+ SetDataTypeToAttr(output_properties.dtype(), "T", output);
+ output->add_input(input_slice->name());
+ } else {
+ NodeDef* axis = AddEmptyNode(
+ OptimizedNodeName(ParseNodeScopeAndName(node->name()), "Axis"));
+ axis->set_op("Const");
+ axis->set_device(node->device());
+ auto axis_attr = axis->mutable_attr();
+ SetDataTypeToAttr(DT_INT32, "dtype", axis);
+ auto* axis_t = (*axis_attr)["value"].mutable_tensor();
+ axis_t->set_dtype(DT_INT32);
+ axis_t->add_int_val(pack_axis);
+ AddToOptimizationQueue(axis);
+ output->set_op("ExpandDims");
+ output->set_device(node->device());
+ SetDataTypeToAttr(output_properties.dtype(), "T", output);
+ output->add_input(input_slice->name());
+ output->add_input(axis->name());
+ }
+
+ // Copy dependencies over.
+ ForwardControlDependencies(output, {node, pack});
+ AddToOptimizationQueue(output);
+ *simplified_node_name = output->name();
+
+ return Status::OK();
+ }
+};
+
} // namespace
class UniqueNodes {
@@ -3132,7 +3419,7 @@ Status ArithmeticOptimizer::SimplifyArithmeticOps(bool can_use_shapes) {
const GraphOptimizerContext ctx(&nodes_to_preserve_, optimized_graph_,
graph_properties_.get(), node_map_.get(),
- opt_level_);
+ &feed_nodes_, opt_level_);
const ArithmeticOptimizerContext ctx_ext(&nodes_to_simplify);
// Stop pipeline after first stage returning non-empty simplified tensor name.
@@ -3186,6 +3473,8 @@ Status ArithmeticOptimizer::SimplifyArithmeticOps(bool can_use_shapes) {
pipeline.AddStage<ConvertExpm1Stage>(ctx, ctx_ext);
if (options_.unary_ops_composition)
pipeline.AddStage<UnaryOpsComposition>(ctx, ctx_ext);
+ if (options_.remove_stack_strided_slice_same_axis)
+ pipeline.AddStage<RemoveStackStridedSliceSameAxis>(ctx, ctx_ext);
VLOG(1) << "Run " << pipeline.NumStages() << " arithmetic optimizer stages: "
<< str_util::Join(pipeline.StageNames(), ", ");
@@ -3249,6 +3538,10 @@ Status ArithmeticOptimizer::Optimize(Cluster* /*cluster*/,
optimized_graph_ = &optimized_item.graph;
node_map_.reset(new NodeMap(optimized_graph_));
+ for (const auto& feed : item.feed) {
+ feed_nodes_.insert(NodeName(feed.first));
+ }
+
// Disable restricted graph rewrites.
options_.unary_ops_composition &=
item.allowed_optimizations.non_differentiable_rewrites;
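The Identity-versus-ExpandDims choice in RewriteGraph above matches ordinary slicing semantics. A minimal, purely illustrative Python sketch (not part of the commit):

    import tensorflow as tf

    a = tf.ones([2, 2])
    b = tf.zeros([2, 2])
    stacked = tf.stack([a, b], axis=1)  # shape [2, 2, 2]

    # Shrink-axis slice: the result already has a's shape, so the stage can
    # replace the StridedSlice with Identity(a).
    assert stacked[:, 0, :].shape == a.shape
    # Width-1 range slice: the packed axis survives in the output, so the
    # replacement is ExpandDims(a, axis=1).
    assert stacked[:, 0:1, :].shape == tf.expand_dims(a, axis=1).shape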
diff --git a/tensorflow/core/grappler/optimizers/arithmetic_optimizer.h b/tensorflow/core/grappler/optimizers/arithmetic_optimizer.h
index d457eb6d21..bb56f61e30 100644
--- a/tensorflow/core/grappler/optimizers/arithmetic_optimizer.h
+++ b/tensorflow/core/grappler/optimizers/arithmetic_optimizer.h
@@ -20,6 +20,7 @@ limitations under the License.
#include "tensorflow/core/grappler/costs/graph_properties.h"
#include "tensorflow/core/grappler/optimizers/graph_optimizer.h"
#include "tensorflow/core/grappler/utils.h"
+#include "tensorflow/core/lib/gtl/flatset.h"
#include "tensorflow/core/protobuf/rewriter_config.pb.h"
namespace tensorflow {
@@ -79,6 +80,7 @@ class ArithmeticOptimizer : public GraphOptimizer {
bool convert_log1p = true;
bool convert_expm1 = true;
bool unary_ops_composition = true;
+ bool remove_stack_strided_slice_same_axis = false;
// Choose which arithmetic optimizer stages will be enabled for a given
// optimization level by default.
@@ -128,6 +130,7 @@ class ArithmeticOptimizer : public GraphOptimizer {
std::unique_ptr<NodeMap> node_map_;
std::unique_ptr<GraphProperties> graph_properties_;
GraphDef* optimized_graph_ = nullptr; // Not owned.
+ gtl::FlatSet<string> feed_nodes_;
};
} // end namespace grappler
diff --git a/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc b/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc
index 77f3c64c65..d091b26b65 100644
--- a/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc
+++ b/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc
@@ -288,6 +288,12 @@ class ArithmeticOptimizerTest : public GrapplerTest {
DisableAllStages(optimizer);
optimizer->options_.unary_ops_composition = true;
}
+
+ void EnableOnlyRemoveStackStridedSliceSameAxis(
+ ArithmeticOptimizer* optimizer) {
+ DisableAllStages(optimizer);
+ optimizer->options_.remove_stack_strided_slice_same_axis = true;
+ }
};
TEST_F(ArithmeticOptimizerTest, NoOp) {
@@ -3364,5 +3370,210 @@ TEST_F(ArithmeticOptimizerTest, UnaryOpsComposition) {
test::ExpectTensorNear<float>(tensors_expected[0], tensors[0], 1e-6);
}
+TEST_F(ArithmeticOptimizerTest, RemoveStackStridedSliceSameAxis) {
+ tensorflow::Scope s = tensorflow::Scope::NewRootScope();
+ auto a_in =
+ ops::Const(s.WithOpName("a_in"), {1.0f, 2.0f, 3.0f, 4.0f}, {2, 2});
+ auto b_in =
+ ops::Const(s.WithOpName("b_in"), {-1.0f, -2.0f, -3.0f, -4.0f}, {2, 2});
+ auto c_in =
+ ops::Const(s.WithOpName("c_in"), {5.0f, 6.0f, 7.0f, 8.0f}, {2, 2});
+ auto a = ops::PlaceholderWithDefault(s.WithOpName("a"), a_in,
+ PartialTensorShape({-1, -1}));
+ auto b = ops::PlaceholderWithDefault(s.WithOpName("b"), b_in,
+ PartialTensorShape({-1, -1}));
+ auto c = ops::PlaceholderWithDefault(s.WithOpName("c"), c_in,
+ PartialTensorShape({-1, -1}));
+ // stacked = tf.stack((a, b, c), axis=1).
+ // stacked.shape == [2, 3, 2] (a, b, c are stacked along new axis 1)
+ auto stacked =
+ ops::Stack(s.WithOpName("stacked"), {a.output, b.output, c.output},
+ ops::Stack::Axis(1));
+ auto expanded_a = ops::ExpandDims(s.WithOpName("expanded_a"), a, {1});
+ auto expanded_b = ops::ExpandDims(s.WithOpName("expanded_b"), b, {1});
+ auto expanded_c = ops::ExpandDims(s.WithOpName("expanded_c"), c, {1});
+ auto begin_a = ops::Const(s.WithOpName("begin_a"), {0, 0, 0}, {3});
+ auto end_a = ops::Const(s.WithOpName("end_a"), {0, 1, 0}, {3});
+ auto begin_b = ops::Const(s.WithOpName("begin_b"), {0, 1, 0}, {3});
+ auto end_b = ops::Const(s.WithOpName("end_b"), {0, 2, 0}, {3});
+ auto begin_c = ops::Const(s.WithOpName("begin_c"), {0, 2, 0}, {3});
+ auto end_c = ops::Const(s.WithOpName("end_c"), {0, 3, 0}, {3});
+ auto end_c_2to = ops::Const(s.WithOpName("end_c_2to"), {0, 0, 0}, {3});
+ auto strides = ops::Const(s.WithOpName("strides"), {1, 1, 1}, {3});
+
+ // stacked[:, 0]
+ using SS = ops::StridedSlice;
+ auto pa_slice = ops::Identity(
+ s.WithOpName("pa_slice_out"),
+ SS(s.WithOpName("pa_slice"), stacked, begin_a, end_a, strides,
+ SS::BeginMask(0b0101) // 5
+ .EllipsisMask(0)
+ .EndMask(0b0101) // 5
+ .NewAxisMask(0)
+ .ShrinkAxisMask(0b0010))); // 2
+
+ // stacked[:, 1]
+ auto pb_slice = ops::Identity(
+ s.WithOpName("pb_slice_out"),
+ SS(s.WithOpName("pb_slice"), stacked, begin_b, end_b, strides,
+ SS::BeginMask(0b0101) // 5
+ .EllipsisMask(0)
+ .EndMask(0b0101) // 5
+ .NewAxisMask(0)
+ .ShrinkAxisMask(0b0010))); // 2
+
+ // stacked[:, 2]
+ auto pc_slice = ops::Identity(
+ s.WithOpName("pc_slice_out"),
+ SS(s.WithOpName("pc_slice"), stacked, begin_c, end_c, strides,
+ SS::BeginMask(0b0101) // 5
+ .EllipsisMask(0)
+ .EndMask(0b0101) // 5
+ .NewAxisMask(0)
+ .ShrinkAxisMask(0b0010))); // 2
+
+ // stacked[:, 0:1, :]
+ auto pa_slice_01 = ops::Identity(
+ s.WithOpName("pa_slice_01_out"),
+ SS(s.WithOpName("pa_slice_01"), stacked, begin_a, end_a, strides,
+ SS::BeginMask(0b0101) // 5
+ .EllipsisMask(0)
+ .EndMask(0b0101) // 5
+ .NewAxisMask(0)
+ .ShrinkAxisMask(0)));
+
+ // stacked[:, :1, :]
+ auto pa_slice_to1 = ops::Identity(
+ s.WithOpName("pa_slice_to1_out"),
+ SS(s.WithOpName("pa_slice_to1"), stacked, begin_a, end_a, strides,
+ SS::BeginMask(0b0111) // 7
+ .EllipsisMask(0)
+ .EndMask(0b0101) // 5
+ .NewAxisMask(0)
+ .ShrinkAxisMask(0)));
+
+ // stacked[:, 1:2, :]
+ auto pb_slice_12 = ops::Identity(
+ s.WithOpName("pb_slice_12_out"),
+ SS(s.WithOpName("pb_slice_12"), stacked, begin_b, end_b, strides,
+ SS::BeginMask(0b0101) // 5
+ .EllipsisMask(0)
+ .EndMask(0b0101) // 5
+ .NewAxisMask(0)
+ .ShrinkAxisMask(0)));
+
+ // stacked[:, 2:, :].
+ auto pc_slice_2to = ops::Identity(
+ s.WithOpName("pc_slice_2to_out"),
+ SS(s.WithOpName("pc_slice_2to"), stacked, begin_c, end_c_1to, strides,
+ SS::BeginMask(0b0101) // 5
+ .EllipsisMask(0)
+ .EndMask(0b0111) // 7
+ .NewAxisMask(0)
+ .ShrinkAxisMask(0)));
+
+ GrapplerItem item;
+ item.fetch = {"a",
+ "b",
+ "c",
+ "pa_slice_out",
+ "pb_slice_out",
+ "pc_slice_out",
+ "expanded_a",
+ "expanded_b",
+ "expanded_c",
+ "pa_slice_01_out",
+ "pa_slice_to1_out",
+ "pb_slice_12_out",
+ "pc_slice_2to_out"};
+ enum FetchItem {
+ fA,
+ fB,
+ fC,
+ fASliceOut,
+ fBSliceOut,
+ fCSliceOut,
+ fExpandedA,
+ fExpandedB,
+ fExpandedC,
+ fASlice01Out,
+ fASliceTo1Out,
+ fBSlice12Out,
+ fCSlice2ToOut,
+ };
+ TF_CHECK_OK(s.ToGraphDef(&item.graph));
+ auto tensors_expected = EvaluateNodes(item.graph, item.fetch);
+
+ // stacked[:, 0, :] == a.
+ test::ExpectTensorEqual<float>(tensors_expected[fA],
+ tensors_expected[fASliceOut]);
+ // stacked[:, 1, :] == b.
+ test::ExpectTensorEqual<float>(tensors_expected[fB],
+ tensors_expected[fBSliceOut]);
+ // stacked[:, 2, :] == c.
+ test::ExpectTensorEqual<float>(tensors_expected[fC],
+ tensors_expected[fCSliceOut]);
+
+ // stacked[:, 0:1, :] == expand_dims(a, 1).
+ test::ExpectTensorEqual<float>(tensors_expected[fExpandedA],
+ tensors_expected[fASlice01Out]);
+
+ // stacked[:, :1, :] == expand_dims(a, 1).
+ test::ExpectTensorEqual<float>(tensors_expected[fExpandedA],
+ tensors_expected[fASliceTo1Out]);
+
+ // stacked[:, 1:2, :] == expand_dims(b, 1).
+ test::ExpectTensorEqual<float>(tensors_expected[fExpandedB],
+ tensors_expected[fBSlice12Out]);
+ // stacked[:, 2:, :] == expand_dims(c, 1).
+ test::ExpectTensorEqual<float>(tensors_expected[fExpandedC],
+ tensors_expected[fCSlice2ToOut]);
+
+ GraphDef output;
+ ArithmeticOptimizer optimizer;
+ EnableOnlyRemoveStackStridedSliceSameAxis(&optimizer);
+ OptimizeAndPrune(&optimizer, &item, &output);
+
+ for (const auto& node : output.node()) {
+ if (node.name() == "pa_slice_out") {
+ EXPECT_EQ(node.input(0), "a");
+ } else if (node.name() == "pb_slice_out") {
+ EXPECT_EQ(node.input(0), "b");
+ } else if (node.name() == "pc_slice_out") {
+ EXPECT_EQ(node.input(0), "c");
+ } else if (str_util::EndsWith(node.name(), "_out")) {
+ EXPECT_EQ(strings::StrCat(node.input(0), "_out"),
+ strings::StrCat(
+ "ArithmeticOptimizer/RemoveStackStridedSliceSameAxis_",
+ node.name()));
+ }
+ }
+
+ auto tensors = EvaluateNodes(output, item.fetch);
+
+ // stacked[:, 0, :] == a.
+ test::ExpectTensorEqual<float>(tensors_expected[fA], tensors[fASliceOut]);
+
+ // stacked[:, 1, :] == b.
+ test::ExpectTensorEqual<float>(tensors_expected[fB], tensors[fBSliceOut]);
+ // stacked[:, 2, :] == c.
+ test::ExpectTensorEqual<float>(tensors_expected[fC], tensors[fCSliceOut]);
+
+ // stacked[:, 0:1, :] == expand_dims(a, 1).
+ test::ExpectTensorEqual<float>(tensors_expected[fExpandedA],
+ tensors[fASlice01Out]);
+
+ // stacked[:, :1, :] == expand_dims(a, 1).
+ test::ExpectTensorEqual<float>(tensors_expected[fExpandedA],
+ tensors[fASliceTo1Out]);
+
+ // stacked[:, 1:2, :] == expand_dims(b, 1).
+ test::ExpectTensorEqual<float>(tensors_expected[fExpandedB],
+ tensors[fBSlice12Out]);
+ // stacked[:, 2:, :] == expand_dims(c, 1).
+ test::ExpectTensorEqual<float>(tensors_expected[fExpandedC],
+ tensors[fCSlice2ToOut]);
+}
+
} // namespace grappler
} // namespace tensorflow
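For readers decoding the mask constants in the test above: they are the raw StridedSlice encoding of Python slice syntax. An illustrative sketch (values hypothetical, not part of the commit):

    import tensorflow as tf

    stacked = tf.reshape(tf.range(12.0), [2, 3, 2])

    # begin_mask/end_mask 0b101: axes 0 and 2 ignore begin/end and take their
    # full range; shrink_axis_mask 0b010: axis 1 takes the single index 0 and
    # is then squeezed away.
    x = tf.strided_slice(stacked,
                         begin=[0, 0, 0], end=[0, 1, 0], strides=[1, 1, 1],
                         begin_mask=0b101, end_mask=0b101,
                         shrink_axis_mask=0b010)
    # x is the same tensor as stacked[:, 0, :].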
diff --git a/tensorflow/core/grappler/optimizers/graph_optimizer_stage.h b/tensorflow/core/grappler/optimizers/graph_optimizer_stage.h
index 2afb5df431..f31a30ec0e 100644
--- a/tensorflow/core/grappler/optimizers/graph_optimizer_stage.h
+++ b/tensorflow/core/grappler/optimizers/graph_optimizer_stage.h
@@ -21,6 +21,7 @@ limitations under the License.
#include "tensorflow/core/grappler/costs/graph_properties.h"
#include "tensorflow/core/grappler/grappler_item.h"
#include "tensorflow/core/grappler/utils.h"
+#include "tensorflow/core/lib/gtl/flatset.h"
#include "tensorflow/core/protobuf/rewriter_config.pb.h"
namespace tensorflow {
@@ -46,17 +47,20 @@ struct GraphOptimizerContext {
GraphOptimizerContext(const std::unordered_set<string>* nodes_to_preserve,
GraphDef* optimized_graph,
GraphProperties* graph_properties, NodeMap* node_map,
+ gtl::FlatSet<string>* feed_nodes,
RewriterConfig::Toggle opt_level)
: nodes_to_preserve(nodes_to_preserve),
optimized_graph(optimized_graph),
graph_properties(graph_properties),
node_map(node_map),
+ feed_nodes(feed_nodes),
opt_level(opt_level) {}
const std::unordered_set<string>* nodes_to_preserve;
GraphDef* optimized_graph;
GraphProperties* graph_properties;
NodeMap* node_map;
+ gtl::FlatSet<string>* feed_nodes;
RewriterConfig::Toggle opt_level;
};
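The feed_nodes plumbing added here is what backs IsReallyConstant in the new stage: a Const node that is also fed is not constant at runtime, so its attr value must not be trusted. A hedged TF 1.x-style illustration (not part of the commit):

    import tensorflow as tf  # TF 1.x graph-mode API, contemporary with this commit

    c = tf.constant([0, 1, 0])  # a Const node in the GraphDef...
    with tf.Session() as sess:
        # ...whose value the caller can still override with a feed, so an
        # optimizer must not fold it based on the Const attr alone.
        print(sess.run(c, feed_dict={c: [0, 2, 0]}))  # prints [0 2 0]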
diff --git a/tensorflow/core/grappler/optimizers/graph_optimizer_stage_test.cc b/tensorflow/core/grappler/optimizers/graph_optimizer_stage_test.cc
index 34f28c7c27..799c40c67b 100644
--- a/tensorflow/core/grappler/optimizers/graph_optimizer_stage_test.cc
+++ b/tensorflow/core/grappler/optimizers/graph_optimizer_stage_test.cc
@@ -61,6 +61,7 @@ TEST_F(GraphOptimizerStageTest, OptimizedNodeName) {
/*optimized_graph*/ nullptr,
/*graph_properties*/ nullptr,
/*node_name*/ nullptr,
+ /*feed_nodes*/ nullptr,
/*opt_level*/ RewriterConfig::ON);
FakeOptimizerStage stage("my_opt", "my_stg", ctx);
@@ -97,6 +98,7 @@ TEST_F(GraphOptimizerStageTest, GetInputNodeAndProperties) {
/*optimized_graph*/ &item.graph,
/*graph_properties*/ &properties,
/*node_name*/ &node_map,
+ /*feed_nodes*/ nullptr,
/*opt_level*/ RewriterConfig::ON);
FakeOptimizerStage stage("my_opt", "my_stg", ctx);
@@ -137,6 +139,7 @@ TEST_F(GraphOptimizerStageTest, AddNodes) {
/*optimized_graph*/ &item.graph,
/*graph_properties*/ &properties,
/*node_name*/ &node_map,
+ /*feed_nodes*/ nullptr,
/*opt_level*/ RewriterConfig::ON);
FakeOptimizerStage stage("my_opt", "my_stg", ctx);