From 79af30d357fbe0869e163e1d9dce0cb869b3724f Mon Sep 17 00:00:00 2001
From: Eugene Brevdo <ebrevdo@google.com>
Date: Wed, 10 Oct 2018 08:36:36 -0700
Subject: [Grappler] Add RemoveStackStridedSliceSameAxis optimizer.

// Replace operations of the form:
//    x = stack((a_0, a_1, ..., a_{n-1}), axis=k)[:,...,i,...]
// with
//    a_i
// when the strided slice index `i` is applied in the k'th axis.
//
// Similarly, replace operations of the form:
//    x = stack((a_0, a_1, ..., a_{n-1}), axis=k)[:,...,i:i+1,...]
// with
//    expand_dims(a_i, axis=k)
//
PiperOrigin-RevId: 216535346
---
 .../grappler/optimizers/arithmetic_optimizer.cc | 295 ++++++++++++++++++++-
 1 file changed, 294 insertions(+), 1 deletion(-)

diff --git a/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc b/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc
index 7d5014ee0a..0c2686a419 100644
--- a/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc
+++ b/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc
@@ -46,6 +46,7 @@ limitations under the License.
 #include "tensorflow/core/platform/tensor_coding.h"
 #include "tensorflow/core/util/device_name_utils.h"
 #include "tensorflow/core/util/saved_tensor_slice_util.h"
+#include "tensorflow/core/util/strided_slice_op.h"
 
 using tensorflow::strings::StrCat;
 
@@ -157,6 +158,14 @@ void SetSourceDataType(DataType dtype, NodeDef* node) {
   SetDataTypeToAttr(dtype, SourceDataTypeAttrName(*node), node);
 }
 
+Status CheckAttrExists(const NodeDef& node, const string& key) {
+  if (node.attr().count(key) == 0) {
+    return errors::InvalidArgument("Node '", node.name(), "' lacks '", key,
+                                   "' attr: ", node.DebugString());
+  }
+  return Status::OK();
+}
+
 NodeDef* GetTailOfValuePreservingChain(
     const NodeDef& node, const NodeMap& node_map,
     const std::unordered_set<string>& nodes_to_preserve) {
@@ -2902,6 +2911,284 @@ class UnaryOpsComposition : public ArithmeticOptimizerStage {
   std::unordered_set<string> fused_nodes_;
 };
 
+// Replace operations of the form:
+//    x = stack((a_0, a_1, ..., a_{n-1}), axis=k)[:,...,i,...]
+// with
+//    a_i
+// when the strided slice index `i` is applied in the k'th axis.
+//
+// Similarly, replace operations of the form:
+//    x = stack((a_0, a_1, ..., a_{n-1}), axis=k)[:,...,i:i+1,...]
+// with
+//    expand_dims(a_i, axis=k)
+//
+// TODO(ebrevdo): Extend to also replace operations of the form
+//    concat((a_0, a_1, ..., ), axis=k)[:, ..., s_i:s_{i+1}, ...]
+// with
+//    a_i,
+// when
+//    s_i = cumsum(shape(a)[k] for a in (a_0, ...,))[i]
+// and slicing is in the k'th axis.
+class RemoveStackStridedSliceSameAxis : public ArithmeticOptimizerStage {
+ public:
+  explicit RemoveStackStridedSliceSameAxis(
+      const GraphOptimizerContext& ctx,
+      const ArithmeticOptimizerContext& ctx_ext)
+      : ArithmeticOptimizerStage("RemoveStackStridedSliceSameAxis", ctx,
+                                 ctx_ext) {}
+  ~RemoveStackStridedSliceSameAxis() override = default;
+
+  bool IsSupported(const NodeDef* node) const override {
+    return IsStridedSlice(*node);
+  }
+
+  Status TrySimplify(NodeDef* node, string* simplified_node_name) override {
+    // *node is a StridedSlice NodeDef.
+    NodeDef* pack;
+
+    // Get the input and see if it's a Pack op.
+    TF_RETURN_IF_ERROR(GetInputNode(node->input(0), &pack));
+    if (!IsPack(*pack)) return Status::OK();
+
+    bool return_early;
+    PartialTensorShape pack_output_shape;
+    int pack_axis;
+    TF_RETURN_IF_ERROR(
+        CheckInputs(node, pack, &pack_output_shape, &pack_axis, &return_early));
+    if (return_early) return Status::OK();
+
+    int slice_start_value;
+    bool found;
+    TF_RETURN_IF_ERROR(GetSliceAxis(node, pack, pack_output_shape, pack_axis,
+                                    &slice_start_value, &found));
+    if (!found) return Status::OK();
+
+    return RewriteGraph(node, pack, slice_start_value, pack_axis,
+                        simplified_node_name);
+  }
+
+ protected:
+  bool IsReallyConstant(const NodeDef& node) const {
+    if (!IsConstant(node)) {
+      return false;
+    }
+    // If the node is fed it's not constant anymore.
+    return ctx().feed_nodes->find(node.name()) == ctx().feed_nodes->end();
+  }
+
+  bool GetConstantAsInt64(const NodeDef& node, DataType dtype,
+                          std::vector<int64>* values) {
+    if (dtype == DT_INT32) {
+      std::vector<int32> values_int32;
+      if (!ValuesFromConstNode(node, &values_int32)) {
+        return false;
+      }
+      std::copy(values_int32.begin(), values_int32.end(),
+                std::inserter(*values, values->begin()));
+      return true;
+    } else {
+      return ValuesFromConstNode(node, values);
+    }
+  }
+
+  Status CheckInputs(const NodeDef* node, const NodeDef* pack,
+                     PartialTensorShape* pack_output_shape, int* pack_axis,
+                     bool* return_early) {
+    *return_early = true;
+    TF_RETURN_IF_ERROR(CheckAttrExists(*pack, "axis"));
+
+    *pack_axis = pack->attr().at("axis").i();
+    auto slice_properties =
+        ctx().graph_properties->GetInputProperties(node->name());
+    *pack_output_shape = slice_properties[0].shape();
+    if (pack_output_shape->unknown_rank()) {
+      return Status::OK();
+    }
+    const int pack_input_rank = pack_output_shape->dims() - 1;
+    if (*pack_axis < 0) {
+      // The ndims of any input into Pack op is its output ndims - 1.
+      *pack_axis += pack_input_rank;
+    }
+    if (*pack_axis < 0 || *pack_axis >= pack_input_rank) {
+      return errors::InvalidArgument(
+          "Pack node (", pack->name(),
+          ") axis attribute is out of bounds: ", pack->attr().at("axis").i());
+    }
+    *return_early = false;
+    return Status::OK();
+  }
+
+  Status GetSliceAxis(const NodeDef* node, const NodeDef* pack,
+                      const PartialTensorShape& pack_output_shape,
+                      int pack_axis, int* slice_start_value, bool* found) {
+    *found = false;
+    for (auto key : {"begin_mask", "end_mask", "ellipsis_mask", "new_axis_mask",
+                     "shrink_axis_mask"}) {
+      TF_RETURN_IF_ERROR(CheckAttrExists(*node, key));
+    }
+
+    const int begin_mask = node->attr().at("begin_mask").i();
+    const int end_mask = node->attr().at("end_mask").i();
+    const int ellipsis_mask = node->attr().at("ellipsis_mask").i();
+    const int new_axis_mask = node->attr().at("new_axis_mask").i();
+    const int shrink_axis_mask = node->attr().at("shrink_axis_mask").i();
+
+    // Check that the StridedSlice is one of these at pack_axis:
+    //   [..., i, ...]
+    //   [..., i:i+1, ...]
+    //   [..., :1, ...]
+    //   [..., -1:, ...]
+    //   [..., s_{pack_axis}-1:, ...]
+    NodeDef* slice_begin;
+    NodeDef* slice_end;
+    NodeDef* slice_strides;
+    TF_RETURN_IF_ERROR(GetInputNode(node->input(1), &slice_begin));
+    TF_RETURN_IF_ERROR(GetInputNode(node->input(2), &slice_end));
+    TF_RETURN_IF_ERROR(GetInputNode(node->input(3), &slice_strides));
+
+    for (const auto* n : {slice_begin, slice_end, slice_strides}) {
+      if (!IsReallyConstant(*n)) return Status::OK();
+    }
+
+    Tensor slice_begin_t;
+    Tensor slice_end_t;
+    Tensor slice_strides_t;
+
+    TF_RETURN_IF_ERROR(CheckAttrExists(*slice_begin, "value"));
+    TF_RETURN_IF_ERROR(CheckAttrExists(*slice_end, "value"));
+    if (!slice_begin_t.FromProto(slice_begin->attr().at("value").tensor())) {
+      return Status::OK();
+    }
+    if (!slice_end_t.FromProto(slice_end->attr().at("value").tensor())) {
+      return Status::OK();
+    }
+    if (!slice_strides_t.FromProto(
+            slice_strides->attr().at("value").tensor())) {
+      return Status::OK();
+    }
+
+    TensorShape processing_shape;
+    TensorShape final_shape;
+    bool is_identity;
+    bool is_simple_slice;
+    bool slice_dim0;
+    gtl::InlinedVector<int64, 4> slice_begin_vec;
+    gtl::InlinedVector<int64, 4> slice_end_vec;
+    gtl::InlinedVector<int64, 4> slice_strides_vec;
+    TF_RETURN_IF_ERROR(ValidateStridedSliceOp(
+        &slice_begin_t, &slice_end_t, slice_strides_t, pack_output_shape,
+        begin_mask, end_mask, ellipsis_mask, new_axis_mask, shrink_axis_mask,
+        &processing_shape, &final_shape, &is_identity, &is_simple_slice,
+        &slice_dim0, &slice_begin_vec, &slice_end_vec, &slice_strides_vec));
+
+    if (!is_simple_slice) return Status::OK();
+
+    int begin_index = -1;
+    int64 begin_value = 0;
+    for (int i = 0; i < slice_begin_vec.size(); ++i) {
+      const int64 v = slice_begin_vec[i];
+      if (v != 0) {
+        if (begin_index != -1) {
+          // At least two start values that are nonzero.
+          return Status::OK();
+        }
+        begin_index = i;
+        begin_value = v;
+      }
+    }
+
+    int end_index = -1;
+    int64 end_value = 0;
+    for (int i = 0; i < slice_end_vec.size(); ++i) {
+      const int64 v = slice_end_vec[i];
+      if (v != pack_output_shape.dim_size(i)) {
+        if (end_index != -1) {
+          // At least two end values that don't match the dim size.
+          return Status::OK();
+        }
+        end_index = i;
+        end_value = v;
+      }
+    }
+
+    if (begin_index == -1 && end_index == -1) return Status::OK();
+    if (begin_index != -1 && end_index != -1 && begin_index != end_index) {
+      // Somehow received different axes for begin/end slicing.
+      return Status::OK();
+    }
+    const int slice_axis = (begin_index == -1) ? end_index : begin_index;
+    if (slice_axis != pack_axis) {
+      // Not slicing on the same axis as the Pack op.
+      return Status::OK();
+    }
+    *slice_start_value = (begin_index == -1) ? 0 : begin_value;
+    const int64 slice_end_value =
+        (end_index == -1) ? pack_output_shape.dim_size(slice_axis) : end_value;
+    if (slice_end_value != *slice_start_value + 1) {
+      // Not slicing a single value out.
+      return Status::OK();
+    }
+
+    if (*slice_start_value < 0 || *slice_start_value >= pack->input_size()) {
+      return errors::InvalidArgument(
+          "Node ", node->name(), " requested invalid slice index ",
+          *slice_start_value, " on axis ", slice_axis,
+          " from tensor of shape: ", pack_output_shape.DebugString());
+    }
+
+    *found = true;  // slice_start_value is valid.
+    return Status::OK();
+  }
+
+  Status RewriteGraph(const NodeDef* node, const NodeDef* pack,
+                      int slice_start_value, int pack_axis,
+                      string* simplified_node_name) {
+    OpInfo::TensorProperties input_slice_properties;
+    NodeDef* input_slice;
+    TF_RETURN_IF_ERROR(
+        GetInputNode(pack->input(slice_start_value), &input_slice));
+    TF_RETURN_IF_ERROR(GetTensorProperties(pack->input(slice_start_value),
+                                           &input_slice_properties));
+    PartialTensorShape input_slice_shape(input_slice_properties.shape());
+
+    OpInfo::TensorProperties output_properties;
+    TF_RETURN_IF_ERROR(GetTensorProperties(
+        strings::StrCat(node->name(), ":", 0), &output_properties));
+    PartialTensorShape output_shape(output_properties.shape());
+    NodeDef* output =
+        AddEmptyNode(OptimizedNodeName(ParseNodeScopeAndName(node->name())));
+    if (input_slice_shape.IsCompatibleWith(output_shape)) {
+      output->set_op("Identity");
+      output->set_device(node->device());
+      SetDataTypeToAttr(output_properties.dtype(), "T", output);
+      output->add_input(input_slice->name());
+    } else {
+      NodeDef* axis = AddEmptyNode(
+          OptimizedNodeName(ParseNodeScopeAndName(node->name()), "Axis"));
+      axis->set_op("Const");
+      axis->set_device(node->device());
+      auto axis_attr = axis->mutable_attr();
+      SetDataTypeToAttr(DT_INT32, "dtype", axis);
+      auto* axis_t = (*axis_attr)["value"].mutable_tensor();
+      axis_t->set_dtype(DT_INT32);
+      axis_t->add_int_val(pack_axis);
+      AddToOptimizationQueue(axis);
+      output->set_op("ExpandDims");
+      output->set_device(node->device());
+      SetDataTypeToAttr(output_properties.dtype(), "T", output);
+      output->add_input(input_slice->name());
+      output->add_input(axis->name());
+    }
+
+    // Copy dependencies over.
+    ForwardControlDependencies(output, {node, pack});
+    AddToOptimizationQueue(output);
+    *simplified_node_name = output->name();
+
+    return Status::OK();
+  }
+};
+
 }  // namespace
 
 class UniqueNodes {
@@ -3132,7 +3419,7 @@ Status ArithmeticOptimizer::SimplifyArithmeticOps(bool can_use_shapes) {
 
   const GraphOptimizerContext ctx(&nodes_to_preserve_, optimized_graph_,
                                   graph_properties_.get(), node_map_.get(),
-                                  opt_level_);
+                                  &feed_nodes_, opt_level_);
   const ArithmeticOptimizerContext ctx_ext(&nodes_to_simplify);
 
   // Stop pipeline after first stage returning non-empty simplified tensor
   // name.
@@ -3186,6 +3473,8 @@ Status ArithmeticOptimizer::SimplifyArithmeticOps(bool can_use_shapes) {
     pipeline.AddStage<ConvertExpm1Stage>(ctx, ctx_ext);
   if (options_.unary_ops_composition)
     pipeline.AddStage<UnaryOpsComposition>(ctx, ctx_ext);
+  if (options_.remove_stack_strided_slice_same_axis)
+    pipeline.AddStage<RemoveStackStridedSliceSameAxis>(ctx, ctx_ext);
 
   VLOG(1) << "Run " << pipeline.NumStages() << " arithmetic optimizer stages: "
           << str_util::Join(pipeline.StageNames(), ", ");
@@ -3249,6 +3538,10 @@ Status ArithmeticOptimizer::Optimize(Cluster* /*cluster*/,
   optimized_graph_ = &optimized_item.graph;
   node_map_.reset(new NodeMap(optimized_graph_));
 
+  for (const auto& feed : item.feed) {
+    feed_nodes_.insert(NodeName(feed.first));
+  }
+
   // Disable restricted graph rewrites.
   options_.unary_ops_composition &=
       item.allowed_optimizations.non_differentiable_rewrites;
--
cgit v1.2.3
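
For context, the rewrite this stage performs corresponds to the following
TensorFlow 1.x Python sketch (the placeholder names and shapes are
illustrative assumptions, not part of the commit):

    # Sketch of the pattern RemoveStackStridedSliceSameAxis rewrites.
    import tensorflow as tf

    a = tf.placeholder(tf.float32, shape=[8, 16], name="a")
    b = tf.placeholder(tf.float32, shape=[8, 16], name="b")
    c = tf.placeholder(tf.float32, shape=[8, 16], name="c")

    x = tf.stack([a, b, c], axis=1)  # Pack op; output shape [8, 3, 16].

    # Shrinking slice on the pack axis: StridedSlice(Pack(...)) is
    # element-for-element equal to `b`, so it can become Identity(b).
    y1 = x[:, 1, :]    # shape [8, 16]

    # Size-1, non-shrinking slice on the pack axis: equal to
    # tf.expand_dims(b, 1), so it can become ExpandDims(b, axis=1).
    y2 = x[:, 1:2, :]  # shape [8, 1, 16]

The Identity branch of RewriteGraph covers the shrink case (y1), where the
slice's output shape is compatible with the packed input's shape; otherwise
the stage materializes a Const axis node and emits ExpandDims (the y2 case).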
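The single-element-slice test in GetSliceAxis can be summarized in pure
Python. This is an illustrative sketch only (the function name is
hypothetical, and the real code first canonicalizes begin/end/strides
through ValidateStridedSliceOp):

    def find_single_element_slice(begin, end, shape, pack_axis):
        """Returns the index sliced out along pack_axis if the slice takes
        exactly one element there and is a no-op on every other axis;
        otherwise returns None."""
        begin_hits = [i for i, v in enumerate(begin) if v != 0]
        end_hits = [i for i, v in enumerate(end) if v != shape[i]]
        if len(begin_hits) > 1 or len(end_hits) > 1:
            return None  # Slices on more than one axis.
        if not begin_hits and not end_hits:
            return None  # Identity slice; nothing to rewrite.
        axes = set(begin_hits + end_hits)
        if len(axes) != 1 or axes.pop() != pack_axis:
            return None  # Not slicing on the Pack axis.
        start, stop = begin[pack_axis], end[pack_axis]
        if stop != start + 1:
            return None  # Takes more than one element.
        return start

    # x = stack([a0, a1, a2], axis=1) has shape [8, 3, 16]; the slice
    # x[:, 1:2, :] canonicalizes to begin=[0, 1, 0], end=[8, 2, 16].
    assert find_single_element_slice([0, 1, 0], [8, 2, 16], [8, 3, 16], 1) == 1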