From 79af30d357fbe0869e163e1d9dce0cb869b3724f Mon Sep 17 00:00:00 2001 From: Eugene Brevdo Date: Wed, 10 Oct 2018 08:36:36 -0700 Subject: [Grappler] Add RemoveStackStridedSliceSameAxis optimizer. // Replace operations of the form: // x = stack((a_0, a_1, ..., a_{n-1}), axis=k)[:,...,i,...] // with // a_i // when the strided slice index `i` is applied in the k'th axis. // // Similarly, replace operations of the form: // x = stack((a_0, a_1, ..., a_{n-1}), axis=k)[:,...,i:i+1,...] // with // expand_dims(a_i, axis=k) // PiperOrigin-RevId: 216535346 --- .../grappler/optimizers/arithmetic_optimizer.cc | 295 ++++++++++++++++++++- .../grappler/optimizers/arithmetic_optimizer.h | 3 + .../optimizers/arithmetic_optimizer_test.cc | 211 +++++++++++++++ .../grappler/optimizers/graph_optimizer_stage.h | 4 + .../optimizers/graph_optimizer_stage_test.cc | 3 + 5 files changed, 515 insertions(+), 1 deletion(-) diff --git a/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc b/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc index 7d5014ee0a..0c2686a419 100644 --- a/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc +++ b/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc @@ -46,6 +46,7 @@ limitations under the License. #include "tensorflow/core/platform/tensor_coding.h" #include "tensorflow/core/util/device_name_utils.h" #include "tensorflow/core/util/saved_tensor_slice_util.h" +#include "tensorflow/core/util/strided_slice_op.h" using tensorflow::strings::StrCat; @@ -157,6 +158,14 @@ void SetSourceDataType(DataType dtype, NodeDef* node) { SetDataTypeToAttr(dtype, SourceDataTypeAttrName(*node), node); } +Status CheckAttrExists(const NodeDef& node, const string& key) { + if (node.attr().count(key) == 0) { + return errors::InvalidArgument("Node '", node.name(), "'lacks '", key, + "' attr: ", node.DebugString()); + } + return Status::OK(); +} + NodeDef* GetTailOfValuePreservingChain( const NodeDef& node, const NodeMap& node_map, const std::unordered_set& nodes_to_preserve) { @@ -2902,6 +2911,284 @@ class UnaryOpsComposition : public ArithmeticOptimizerStage { std::unordered_set fused_nodes_; }; +// Replace operations of the form: +// x = stack((a_0, a_1, ..., a_{n-1}), axis=k)[:,...,i,...] +// with +// a_i +// when the strided slice index `i` is applied in the k'th axis. +// +// Similarly, replace operations of the form: +// x = stack((a_0, a_1, ..., a_{n-1}), axis=k)[:,...,i:i+1,...] +// with +// expand_dims(a_i, axis=k) +// +// TODO(ebrevdo): Extend to also replace operations of the form +// concat((a_0, a_1, ..., ), axis=k)[:, ..., s_i:s_{i+1}, ...] +// with +// a_i, +// when +// s_i = cumsum(shape(a)[k] for a in (a_0, ...,))[i] +// and slicing is in the k'th axis. +class RemoveStackStridedSliceSameAxis : public ArithmeticOptimizerStage { + public: + explicit RemoveStackStridedSliceSameAxis( + const GraphOptimizerContext& ctx, + const ArithmeticOptimizerContext& ctx_ext) + : ArithmeticOptimizerStage("RemoveStackStridedSliceSameAxis", ctx, + ctx_ext) {} + ~RemoveStackStridedSliceSameAxis() override = default; + + bool IsSupported(const NodeDef* node) const override { + return IsStridedSlice(*node); + } + + Status TrySimplify(NodeDef* node, string* simplified_node_name) override { + // *node is a StridedSlice NodeDef. + NodeDef* pack; + + // Get the input and see if it's a Pack op. 
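+    // Only the pattern stack((a_0, ..., a_{n-1}), axis=k) followed by a
+    // single-element StridedSlice along axis k is handled; for any other
+    // producer we return OK without a rewrite so later stages still run.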
+ TF_RETURN_IF_ERROR(GetInputNode(node->input(0), &pack)); + if (!IsPack(*pack)) return Status::OK(); + + bool return_early; + PartialTensorShape pack_output_shape; + int pack_axis; + TF_RETURN_IF_ERROR( + CheckInputs(node, pack, &pack_output_shape, &pack_axis, &return_early)); + if (return_early) return Status::OK(); + + int slice_start_value; + bool found; + TF_RETURN_IF_ERROR(GetSliceAxis(node, pack, pack_output_shape, pack_axis, + &slice_start_value, &found)); + if (!found) return Status::OK(); + + return RewriteGraph(node, pack, slice_start_value, pack_axis, + simplified_node_name); + } + + protected: + bool IsReallyConstant(const NodeDef& node) const { + if (!IsConstant(node)) { + return false; + } + // If the node is fed it's not constant anymore. + return ctx().feed_nodes->find(node.name()) == ctx().feed_nodes->end(); + } + + bool GetConstantAsInt64(const NodeDef& node, DataType dtype, + std::vector* values) { + if (dtype == DT_INT32) { + std::vector values_int32; + if (!ValuesFromConstNode(node, &values_int32)) { + return false; + } + std::copy(values_int32.begin(), values_int32.end(), + std::inserter(*values, values->begin())); + return true; + } else { + return ValuesFromConstNode(node, values); + } + } + + Status CheckInputs(const NodeDef* node, const NodeDef* pack, + PartialTensorShape* pack_output_shape, int* pack_axis, + bool* return_early) { + *return_early = true; + TF_RETURN_IF_ERROR(CheckAttrExists(*pack, "axis")); + + *pack_axis = pack->attr().at("axis").i(); + auto slice_properties = + ctx().graph_properties->GetInputProperties(node->name()); + *pack_output_shape = slice_properties[0].shape(); + if (pack_output_shape->unknown_rank()) { + return Status::OK(); + } + const int pack_input_rank = pack_output_shape->dims() - 1; + if (*pack_axis < 0) { + // The ndims of any input into Pack op is its output ndims - 1. + *pack_axis += pack_input_rank; + } + if (*pack_axis < 0 || *pack_axis >= pack_input_rank) { + return errors::InvalidArgument( + "Pack node (", pack->name(), + ") axis attribute is out of bounds: ", pack->attr().at("axis").i()); + } + *return_early = false; + return Status::OK(); + } + + Status GetSliceAxis(const NodeDef* node, const NodeDef* pack, + const PartialTensorShape& pack_output_shape, + int pack_axis, int* slice_start_value, bool* found) { + *found = false; + for (auto key : {"begin_mask", "end_mask", "ellipsis_mask", "new_axis_mask", + "shrink_axis_mask"}) { + TF_RETURN_IF_ERROR(CheckAttrExists(*node, key)); + } + + const int begin_mask = node->attr().at("begin_mask").i(); + const int end_mask = node->attr().at("end_mask").i(); + const int ellipsis_mask = node->attr().at("ellipsis_mask").i(); + const int new_axis_mask = node->attr().at("new_axis_mask").i(); + const int shrink_axis_mask = node->attr().at("shrink_axis_mask").i(); + + // Check that the StridedSlice is one of these at pack_axis: + // [..., i, ...] + // [..., i:i+1, ...] + // [..., :1, ...] + // [..., -1:, ...] + /// [..., s_{pack_axis}-1:, ...] 
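+    // In other words, the slice must select exactly one element along the
+    // Pack axis (unit stride, width-one range); wider or strided slices are
+    // left untouched.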
+ NodeDef* slice_begin; + NodeDef* slice_end; + NodeDef* slice_strides; + TF_RETURN_IF_ERROR(GetInputNode(node->input(1), &slice_begin)); + TF_RETURN_IF_ERROR(GetInputNode(node->input(2), &slice_end)); + TF_RETURN_IF_ERROR(GetInputNode(node->input(3), &slice_strides)); + + for (const auto* n : {slice_begin, slice_end, slice_strides}) { + if (!IsReallyConstant(*n)) return Status::OK(); + } + + Tensor slice_begin_t; + Tensor slice_end_t; + Tensor slice_strides_t; + + TF_RETURN_IF_ERROR(CheckAttrExists(*slice_begin, "value")); + TF_RETURN_IF_ERROR(CheckAttrExists(*slice_end, "value")); + + if (!slice_begin_t.FromProto(slice_begin->attr().at("value").tensor())) { + return Status::OK(); + } + if (!slice_end_t.FromProto(slice_end->attr().at("value").tensor())) { + return Status::OK(); + } + if (!slice_strides_t.FromProto( + slice_strides->attr().at("value").tensor())) { + return Status::OK(); + } + TensorShape processing_shape; + TensorShape final_shape; + bool is_identity; + bool is_simple_slice; + bool slice_dim0; + gtl::InlinedVector slice_begin_vec; + gtl::InlinedVector slice_end_vec; + gtl::InlinedVector slice_strides_vec; + TF_RETURN_IF_ERROR(ValidateStridedSliceOp( + &slice_begin_t, &slice_end_t, slice_strides_t, pack_output_shape, + begin_mask, end_mask, ellipsis_mask, new_axis_mask, shrink_axis_mask, + &processing_shape, &final_shape, &is_identity, &is_simple_slice, + &slice_dim0, &slice_begin_vec, &slice_end_vec, &slice_strides_vec)); + + if (!is_simple_slice) return Status::OK(); + + int begin_index = -1; + int64 begin_value = 0; + for (int i = 0; i < slice_begin_vec.size(); ++i) { + const int64 v = slice_begin_vec[i]; + if (v != 0) { + if (begin_index != -1) { + // At least two start values that are nonzero. + return Status::OK(); + } + begin_index = i; + begin_value = v; + } + } + + int end_index = -1; + int64 end_value = 0; + for (int i = 0; i < slice_end_vec.size(); ++i) { + const int64 v = slice_end_vec[i]; + if (v != pack_output_shape.dim_size(i)) { + if (end_index != -1) { + // At least two end values that are nonzero. + return Status::OK(); + } + end_index = i; + end_value = v; + } + } + + if (begin_index == -1 && end_index == -1) return Status::OK(); + if (begin_index != -1 && end_index != -1 && begin_index != end_index) { + // Somehow received different axes for begin/end slicing + return Status::OK(); + } + const int slice_axis = (begin_index == -1) ? end_index : begin_index; + if (slice_axis != pack_axis) { + // Not slicing on the same axis as the Pack op. + return Status::OK(); + } + *slice_start_value = (begin_index == -1) ? 0 : begin_value; + const int64 slice_end_value = + (end_index == -1) ? pack_output_shape.dim_size(slice_axis) : end_value; + if (slice_end_value != *slice_start_value + 1) { + // Not slicing a single value out. + return Status::OK(); + } + + if (*slice_start_value < 0 || *slice_start_value >= pack->input_size()) { + return errors::InvalidArgument( + "Node ", node->name(), " requested invalid slice index ", + *slice_start_value, " on axis ", slice_axis, + " from tensor of shape: ", pack_output_shape.DebugString()); + } + + *found = true; // slice_start_value is valid. 
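+    // The caller (TrySimplify) will now rewrite the StridedSlice to point at
+    // Pack input a_{slice_start_value} directly.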
+ return Status::OK(); + } + + Status RewriteGraph(const NodeDef* node, const NodeDef* pack, + int slice_start_value, int pack_axis, + string* simplified_node_name) { + OpInfo::TensorProperties input_slice_properties; + NodeDef* input_slice; + TF_RETURN_IF_ERROR( + GetInputNode(pack->input(slice_start_value), &input_slice)); + TF_RETURN_IF_ERROR(GetTensorProperties(pack->input(slice_start_value), + &input_slice_properties)); + PartialTensorShape input_slice_shape(input_slice_properties.shape()); + + OpInfo::TensorProperties output_properties; + TF_RETURN_IF_ERROR(GetTensorProperties( + strings::StrCat(node->name(), ":", 0), &output_properties)); + PartialTensorShape output_shape(output_properties.shape()); + NodeDef* output = + AddEmptyNode(OptimizedNodeName(ParseNodeScopeAndName(node->name()))); + if (input_slice_shape.IsCompatibleWith(output_shape)) { + output->set_op("Identity"); + output->set_device(node->device()); + SetDataTypeToAttr(output_properties.dtype(), "T", output); + output->add_input(input_slice->name()); + } else { + NodeDef* axis = AddEmptyNode( + OptimizedNodeName(ParseNodeScopeAndName(node->name()), "Axis")); + axis->set_op("Const"); + axis->set_device(node->device()); + auto axis_attr = axis->mutable_attr(); + SetDataTypeToAttr(DT_INT32, "dtype", axis); + auto* axis_t = (*axis_attr)["value"].mutable_tensor(); + axis_t->set_dtype(DT_INT32); + axis_t->add_int_val(pack_axis); + AddToOptimizationQueue(axis); + output->set_op("ExpandDims"); + output->set_device(node->device()); + SetDataTypeToAttr(output_properties.dtype(), "T", output); + output->add_input(input_slice->name()); + output->add_input(axis->name()); + } + + // Copy dependencies over. + ForwardControlDependencies(output, {node, pack}); + AddToOptimizationQueue(output); + *simplified_node_name = output->name(); + + return Status::OK(); + } +}; + } // namespace class UniqueNodes { @@ -3132,7 +3419,7 @@ Status ArithmeticOptimizer::SimplifyArithmeticOps(bool can_use_shapes) { const GraphOptimizerContext ctx(&nodes_to_preserve_, optimized_graph_, graph_properties_.get(), node_map_.get(), - opt_level_); + &feed_nodes_, opt_level_); const ArithmeticOptimizerContext ctx_ext(&nodes_to_simplify); // Stop pipeline after first stage returning non-empty simplified tensor name. @@ -3186,6 +3473,8 @@ Status ArithmeticOptimizer::SimplifyArithmeticOps(bool can_use_shapes) { pipeline.AddStage(ctx, ctx_ext); if (options_.unary_ops_composition) pipeline.AddStage(ctx, ctx_ext); + if (options_.remove_stack_strided_slice_same_axis) + pipeline.AddStage(ctx, ctx_ext); VLOG(1) << "Run " << pipeline.NumStages() << " arithmetic optimizer stages: " << str_util::Join(pipeline.StageNames(), ", "); @@ -3249,6 +3538,10 @@ Status ArithmeticOptimizer::Optimize(Cluster* /*cluster*/, optimized_graph_ = &optimized_item.graph; node_map_.reset(new NodeMap(optimized_graph_)); + for (const auto& feed : item.feed) { + feed_nodes_.insert(NodeName(feed.first)); + } + // Disable restricted graph rewrites. options_.unary_ops_composition &= item.allowed_optimizations.non_differentiable_rewrites; diff --git a/tensorflow/core/grappler/optimizers/arithmetic_optimizer.h b/tensorflow/core/grappler/optimizers/arithmetic_optimizer.h index d457eb6d21..bb56f61e30 100644 --- a/tensorflow/core/grappler/optimizers/arithmetic_optimizer.h +++ b/tensorflow/core/grappler/optimizers/arithmetic_optimizer.h @@ -20,6 +20,7 @@ limitations under the License. 
#include "tensorflow/core/grappler/costs/graph_properties.h" #include "tensorflow/core/grappler/optimizers/graph_optimizer.h" #include "tensorflow/core/grappler/utils.h" +#include "tensorflow/core/lib/gtl/flatset.h" #include "tensorflow/core/protobuf/rewriter_config.pb.h" namespace tensorflow { @@ -79,6 +80,7 @@ class ArithmeticOptimizer : public GraphOptimizer { bool convert_log1p = true; bool convert_expm1 = true; bool unary_ops_composition = true; + bool remove_stack_strided_slice_same_axis = false; // Choose which arithmetic optimizer stages will be enabled for a given // optimization level by default. @@ -128,6 +130,7 @@ class ArithmeticOptimizer : public GraphOptimizer { std::unique_ptr node_map_; std::unique_ptr graph_properties_; GraphDef* optimized_graph_ = nullptr; // Not owned. + gtl::FlatSet feed_nodes_; }; } // end namespace grappler diff --git a/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc b/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc index 77f3c64c65..d091b26b65 100644 --- a/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc +++ b/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc @@ -288,6 +288,12 @@ class ArithmeticOptimizerTest : public GrapplerTest { DisableAllStages(optimizer); optimizer->options_.unary_ops_composition = true; } + + void EnableOnlyRemoveStackStridedSliceSameAxis( + ArithmeticOptimizer* optimizer) { + DisableAllStages(optimizer); + optimizer->options_.remove_stack_strided_slice_same_axis = true; + } }; TEST_F(ArithmeticOptimizerTest, NoOp) { @@ -3364,5 +3370,210 @@ TEST_F(ArithmeticOptimizerTest, UnaryOpsComposition) { test::ExpectTensorNear(tensors_expected[0], tensors[0], 1e-6); } +TEST_F(ArithmeticOptimizerTest, RemoveStackStridedSliceSameAxis) { + tensorflow::Scope s = tensorflow::Scope::NewRootScope(); + auto a_in = + ops::Const(s.WithOpName("a_in"), {1.0f, 2.0f, 3.0f, 4.0f}, {2, 2}); + auto b_in = + ops::Const(s.WithOpName("b_in"), {-1.0f, -2.0f, -3.0f, -4.0f}, {2, 2}); + auto c_in = + ops::Const(s.WithOpName("c_in"), {5.0f, 6.0f, 7.0f, 8.0f}, {2, 2}); + auto a = ops::PlaceholderWithDefault(s.WithOpName("a"), a_in, + PartialTensorShape({-1, -1})); + auto b = ops::PlaceholderWithDefault(s.WithOpName("b"), b_in, + PartialTensorShape({-1, -1})); + auto c = ops::PlaceholderWithDefault(s.WithOpName("c"), c_in, + PartialTensorShape({-1, -1})); + // stacked = tf.stack((a, b, c), axis=1). 
+ // stacked.shape == [2, 3, 2] (a, b, c are stacked along new axis 1) + auto stacked = + ops::Stack(s.WithOpName("stacked"), {a.output, b.output, c.output}, + ops::Stack::Axis(1)); + auto expanded_a = ops::ExpandDims(s.WithOpName("expanded_a"), a, {1}); + auto expanded_b = ops::ExpandDims(s.WithOpName("expanded_b"), b, {1}); + auto expanded_c = ops::ExpandDims(s.WithOpName("expanded_c"), c, {1}); + auto begin_a = ops::Const(s.WithOpName("begin_a"), {0, 0, 0}, {3}); + auto end_a = ops::Const(s.WithOpName("end_a"), {0, 1, 0}, {3}); + auto begin_b = ops::Const(s.WithOpName("begin_b"), {0, 1, 0}, {3}); + auto end_b = ops::Const(s.WithOpName("end_b"), {0, 2, 0}, {3}); + auto begin_c = ops::Const(s.WithOpName("begin_c"), {0, 2, 0}, {3}); + auto end_c = ops::Const(s.WithOpName("end_c"), {0, 3, 0}, {3}); + auto end_c_1to = ops::Const(s.WithOpName("begin_c_2to"), {0, 0, 0}, {3}); + auto strides = ops::Const(s.WithOpName("strides"), {1, 1, 1}, {3}); + + // stacked[:, 0] + using SS = ops::StridedSlice; + auto pa_slice = ops::Identity( + s.WithOpName("pa_slice_out"), + SS(s.WithOpName("pa_slice"), stacked, begin_a, end_a, strides, + SS::BeginMask(0b0101) // 5 + .EllipsisMask(0) + .EndMask(0b0101) // 5 + .NewAxisMask(0) + .ShrinkAxisMask(0b0010))); // 2 + + // stacked[:, 1] + auto pb_slice = ops::Identity( + s.WithOpName("pb_slice_out"), + SS(s.WithOpName("pb_slice"), stacked, begin_b, end_b, strides, + SS::BeginMask(0b0101) // 5 + .EllipsisMask(0) + .EndMask(0b0101) // 5 + .NewAxisMask(0) + .ShrinkAxisMask(0b0010))); // 2 + + // stacked[:, 2] + auto pc_slice = ops::Identity( + s.WithOpName("pc_slice_out"), + SS(s.WithOpName("pc_slice"), stacked, begin_c, end_c, strides, + SS::BeginMask(0b0101) // 5 + .EllipsisMask(0) + .EndMask(0b0101) // 5 + .NewAxisMask(0) + .ShrinkAxisMask(0b0010))); // 2 + + // stacked[:, 0:1, :] + auto pa_slice_01 = ops::Identity( + s.WithOpName("pa_slice_01_out"), + SS(s.WithOpName("pa_slice_01"), stacked, begin_a, end_a, strides, + SS::BeginMask(0b0101) // 5 + .EllipsisMask(0) + .EndMask(0b0101) // 5 + .NewAxisMask(0) + .ShrinkAxisMask(0))); + + // stacked[:, :1, :] + auto pa_slice_to1 = ops::Identity( + s.WithOpName("pa_slice_to1_out"), + SS(s.WithOpName("pa_slice_to1"), stacked, begin_a, end_a, strides, + SS::BeginMask(0b0111) // 7 + .EllipsisMask(0) + .EndMask(0b0101) // 5 + .NewAxisMask(0) + .ShrinkAxisMask(0))); + + // stacked[:, 1:2, :] + auto pb_slice_12 = ops::Identity( + s.WithOpName("pb_slice_12_out"), + SS(s.WithOpName("pb_slice_12"), stacked, begin_b, end_b, strides, + SS::BeginMask(0b0101) // 5 + .EllipsisMask(0) + .EndMask(0b0101) // 5 + .NewAxisMask(0) + .ShrinkAxisMask(0))); + + // stacked[:, 2:, :]. + auto pc_slice_2to = ops::Identity( + s.WithOpName("pc_slice_2to_out"), + SS(s.WithOpName("pc_slice_2to"), stacked, begin_c, end_c_1to, strides, + SS::BeginMask(0b0101) // 5 + .EllipsisMask(0) + .EndMask(0b0111) // 7 + .NewAxisMask(0) + .ShrinkAxisMask(0))); + + GrapplerItem item; + item.fetch = {"a", + "b", + "c", + "pa_slice_out", + "pb_slice_out", + "pc_slice_out", + "expanded_a", + "expanded_b", + "expanded_c", + "pa_slice_01_out", + "pa_slice_to1_out", + "pb_slice_12_out", + "pc_slice_2to_out"}; + enum FetchItem { + fA, + fB, + fC, + fASliceOut, + fBSliceOut, + fCSliceOut, + fExpandedA, + fExpandedB, + fExpandedC, + fASlice01Out, + fASliceTo1Out, + fBSlice12Out, + fCSlice2ToOut, + }; + TF_CHECK_OK(s.ToGraphDef(&item.graph)); + auto tensors_expected = EvaluateNodes(item.graph, item.fetch); + + // stacked[:, 0, :] == a. 
+ test::ExpectTensorEqual(tensors_expected[fA], + tensors_expected[fASliceOut]); + // stacked[:, 1, :] == b. + test::ExpectTensorEqual(tensors_expected[fB], + tensors_expected[fBSliceOut]); + // stacked[:, 2, :] == c. + test::ExpectTensorEqual(tensors_expected[fC], + tensors_expected[fCSliceOut]); + + // stacked[:, 0:1, :] == expand_dims(a, 1). + test::ExpectTensorEqual(tensors_expected[fExpandedA], + tensors_expected[fASlice01Out]); + + // stacked[:, :1, :] == expand_dims(a, 1). + test::ExpectTensorEqual(tensors_expected[fExpandedA], + tensors_expected[fASliceTo1Out]); + + // stacked[:, 1:2, :] == expand_dims(b, 1). + test::ExpectTensorEqual(tensors_expected[fExpandedB], + tensors_expected[fBSlice12Out]); + // stacked[:, 2:, :] == expand_dims(c, 1). + test::ExpectTensorEqual(tensors_expected[fExpandedC], + tensors_expected[fCSlice2ToOut]); + + GraphDef output; + ArithmeticOptimizer optimizer; + EnableOnlyRemoveStackStridedSliceSameAxis(&optimizer); + OptimizeAndPrune(&optimizer, &item, &output); + + for (const auto& node : output.node()) { + if (node.name() == "pa_slice_out") { + EXPECT_EQ(node.input(0), "a"); + } else if (node.name() == "pb_slice_out") { + EXPECT_EQ(node.input(0), "b"); + } else if (node.name() == "pc_slice_out") { + EXPECT_EQ(node.input(0), "c"); + } else if (str_util::EndsWith(node.name(), "_out")) { + EXPECT_EQ(strings::StrCat(node.input(0), "_out"), + strings::StrCat( + "ArithmeticOptimizer/RemoveStackStridedSliceSameAxis_", + node.name())); + } + } + + auto tensors = EvaluateNodes(output, item.fetch); + + // stacked[:, 0, :] == a. + test::ExpectTensorEqual(tensors_expected[fA], tensors[fASliceOut]); + + // stacked[:, 1, :] == b. + test::ExpectTensorEqual(tensors_expected[fB], tensors[fBSliceOut]); + // stacked[:, 2, :] == c. + test::ExpectTensorEqual(tensors_expected[fC], tensors[fCSliceOut]); + + // stacked[:, 0:1, :] == expand_dims(a, 1). + test::ExpectTensorEqual(tensors_expected[fExpandedA], + tensors[fASlice01Out]); + + // stacked[:, :1, :] == expand_dims(a, 1). + test::ExpectTensorEqual(tensors_expected[fExpandedA], + tensors[fASliceTo1Out]); + + // stacked[:, 1:2, :] == expand_dims(b, 1). + test::ExpectTensorEqual(tensors_expected[fExpandedB], + tensors[fBSlice12Out]); + // stacked[:, 2:, :] == expand_dims(c, 1). + test::ExpectTensorEqual(tensors_expected[fExpandedC], + tensors[fCSlice2ToOut]); +} + } // namespace grappler } // namespace tensorflow diff --git a/tensorflow/core/grappler/optimizers/graph_optimizer_stage.h b/tensorflow/core/grappler/optimizers/graph_optimizer_stage.h index 2afb5df431..f31a30ec0e 100644 --- a/tensorflow/core/grappler/optimizers/graph_optimizer_stage.h +++ b/tensorflow/core/grappler/optimizers/graph_optimizer_stage.h @@ -21,6 +21,7 @@ limitations under the License. 
#include "tensorflow/core/grappler/costs/graph_properties.h" #include "tensorflow/core/grappler/grappler_item.h" #include "tensorflow/core/grappler/utils.h" +#include "tensorflow/core/lib/gtl/flatset.h" #include "tensorflow/core/protobuf/rewriter_config.pb.h" namespace tensorflow { @@ -46,17 +47,20 @@ struct GraphOptimizerContext { GraphOptimizerContext(const std::unordered_set* nodes_to_preserve, GraphDef* optimized_graph, GraphProperties* graph_properties, NodeMap* node_map, + gtl::FlatSet* feed_nodes, RewriterConfig::Toggle opt_level) : nodes_to_preserve(nodes_to_preserve), optimized_graph(optimized_graph), graph_properties(graph_properties), node_map(node_map), + feed_nodes(feed_nodes), opt_level(opt_level) {} const std::unordered_set* nodes_to_preserve; GraphDef* optimized_graph; GraphProperties* graph_properties; NodeMap* node_map; + gtl::FlatSet* feed_nodes; RewriterConfig::Toggle opt_level; }; diff --git a/tensorflow/core/grappler/optimizers/graph_optimizer_stage_test.cc b/tensorflow/core/grappler/optimizers/graph_optimizer_stage_test.cc index 34f28c7c27..799c40c67b 100644 --- a/tensorflow/core/grappler/optimizers/graph_optimizer_stage_test.cc +++ b/tensorflow/core/grappler/optimizers/graph_optimizer_stage_test.cc @@ -61,6 +61,7 @@ TEST_F(GraphOptimizerStageTest, OptimizedNodeName) { /*optimized_graph*/ nullptr, /*graph_properties*/ nullptr, /*node_name*/ nullptr, + /*feed_nodes*/ nullptr, /*opt_level*/ RewriterConfig::ON); FakeOptimizerStage stage("my_opt", "my_stg", ctx); @@ -97,6 +98,7 @@ TEST_F(GraphOptimizerStageTest, GetInputNodeAndProperties) { /*optimized_graph*/ &item.graph, /*graph_properties*/ &properties, /*node_name*/ &node_map, + /*feed_nodes*/ nullptr, /*opt_level*/ RewriterConfig::ON); FakeOptimizerStage stage("my_opt", "my_stg", ctx); @@ -137,6 +139,7 @@ TEST_F(GraphOptimizerStageTest, AddNodes) { /*optimized_graph*/ &item.graph, /*graph_properties*/ &properties, /*node_name*/ &node_map, + /*feed_nodes*/ nullptr, /*opt_level*/ RewriterConfig::ON); FakeOptimizerStage stage("my_opt", "my_stg", ctx); -- cgit v1.2.3