aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc')
-rw-r--r--tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc295
1 files changed, 294 insertions, 1 deletions
diff --git a/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc b/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc
index 7d5014ee0a..0c2686a419 100644
--- a/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc
+++ b/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc
@@ -46,6 +46,7 @@ limitations under the License.
#include "tensorflow/core/platform/tensor_coding.h"
#include "tensorflow/core/util/device_name_utils.h"
#include "tensorflow/core/util/saved_tensor_slice_util.h"
+#include "tensorflow/core/util/strided_slice_op.h"
using tensorflow::strings::StrCat;
@@ -157,6 +158,14 @@ void SetSourceDataType(DataType dtype, NodeDef* node) {
SetDataTypeToAttr(dtype, SourceDataTypeAttrName(*node), node);
}
+// Verifies that `node` carries an attribute named `key`.
+//
+// Returns OK when the attribute is present. Otherwise returns InvalidArgument
+// naming the node and the missing attribute, followed by the node's debug
+// string.
+Status CheckAttrExists(const NodeDef& node, const string& key) {
+  if (node.attr().count(key) == 0) {
+    // The space after the closing quote keeps the message readable:
+    // "Node 'foo' lacks 'bar' attr: ...".
+    return errors::InvalidArgument("Node '", node.name(), "' lacks '", key,
+                                   "' attr: ", node.DebugString());
+  }
+  return Status::OK();
+}
+
NodeDef* GetTailOfValuePreservingChain(
const NodeDef& node, const NodeMap& node_map,
const std::unordered_set<string>& nodes_to_preserve) {
@@ -2902,6 +2911,284 @@ class UnaryOpsComposition : public ArithmeticOptimizerStage {
std::unordered_set<string> fused_nodes_;
};
+// Replace operations of the form:
+// x = stack((a_0, a_1, ..., a_{n-1}), axis=k)[:,...,i,...]
+// with
+// a_i
+// when the strided slice index `i` is applied in the k'th axis.
+//
+// Similarly, replace operations of the form:
+// x = stack((a_0, a_1, ..., a_{n-1}), axis=k)[:,...,i:i+1,...]
+// with
+// expand_dims(a_i, axis=k)
+//
+// TODO(ebrevdo): Extend to also replace operations of the form
+// concat((a_0, a_1, ..., ), axis=k)[:, ..., s_i:s_{i+1}, ...]
+// with
+// a_i,
+// when
+// s_i = cumsum(shape(a)[k] for a in (a_0, ...,))[i]
+// and slicing is in the k'th axis.
+class RemoveStackStridedSliceSameAxis : public ArithmeticOptimizerStage {
+ public:
+ explicit RemoveStackStridedSliceSameAxis(  // Builds the stage.
+ const GraphOptimizerContext& ctx,
+ const ArithmeticOptimizerContext& ctx_ext)
+ : ArithmeticOptimizerStage("RemoveStackStridedSliceSameAxis", ctx,
+ ctx_ext) {}  // Both contexts are forwarded to the base stage unchanged.
+ ~RemoveStackStridedSliceSameAxis() override = default;  // Nothing extra to release.
+
+ // This stage is only offered StridedSlice nodes; whether the slice actually
+ // reads from a Pack op is decided later, in TrySimplify.
+ bool IsSupported(const NodeDef* node) const override {
+   const NodeDef& candidate = *node;
+   return IsStridedSlice(candidate);
+ }
+
+ // Attempts the Pack + StridedSlice rewrite on `node`.
+ //
+ // Returns OK without touching *simplified_node_name whenever the pattern does
+ // not apply; errors from the helper steps are propagated unchanged.
+ Status TrySimplify(NodeDef* node, string* simplified_node_name) override {
+   // `node` is a StridedSlice (see IsSupported). The rewrite only applies when
+   // the tensor being sliced is produced by a Pack op.
+   NodeDef* producer = nullptr;
+   TF_RETURN_IF_ERROR(GetInputNode(node->input(0), &producer));
+   if (!IsPack(*producer)) {
+     return Status::OK();
+   }
+
+   // Step 1: validate the Pack node and resolve its axis and output shape.
+   PartialTensorShape packed_shape;
+   int stack_axis = 0;
+   bool bail_out = true;
+   TF_RETURN_IF_ERROR(
+       CheckInputs(node, producer, &packed_shape, &stack_axis, &bail_out));
+   if (bail_out) {
+     return Status::OK();
+   }
+
+   // Step 2: check that the slice picks a single element on the Pack axis.
+   int selected_input = 0;
+   bool pattern_matched = false;
+   TF_RETURN_IF_ERROR(GetSliceAxis(node, producer, packed_shape, stack_axis,
+                                   &selected_input, &pattern_matched));
+   if (!pattern_matched) {
+     return Status::OK();
+   }
+
+   // Step 3: emit the replacement node.
+   return RewriteGraph(node, producer, selected_input, stack_axis,
+                       simplified_node_name);
+ }
+
+ protected:
+ // Returns true only for Const nodes that are not listed among the feed
+ // nodes: a fed Const no longer has a fixed value.
+ bool IsReallyConstant(const NodeDef& node) const {
+   // The feed lookup is only performed once the node is known to be a Const.
+   return IsConstant(node) &&
+          ctx().feed_nodes->find(node.name()) == ctx().feed_nodes->end();
+ }
+
+ // Reads the values of constant `node` into `values` as int64.
+ //
+ // When `dtype` is DT_INT32 the values are read as int32 and widened, then
+ // placed in order at the front of `values`. For any other `dtype` the values
+ // are read directly as int64. Returns false if the values cannot be read.
+ bool GetConstantAsInt64(const NodeDef& node, DataType dtype,
+                         std::vector<int64>* values) {
+   if (dtype != DT_INT32) {
+     return ValuesFromConstNode(node, values);
+   }
+
+   std::vector<int32> narrow_values;
+   if (!ValuesFromConstNode(node, &narrow_values)) {
+     return false;
+   }
+   // Range insert performs the int32 -> int64 conversion element by element.
+   values->insert(values->begin(), narrow_values.begin(), narrow_values.end());
+   return true;
+ }
+
+ // Validates the Pack producer of StridedSlice `node` and resolves its axis.
+ //
+ // Outputs:
+ //   pack_output_shape: shape recorded in the graph properties for the first
+ //     input of `node`, i.e. the output of `pack`.
+ //   pack_axis: the Pack "axis" attribute, normalized to [0, output rank).
+ //   return_early: true when the rewrite cannot be attempted (no input
+ //     properties are available, or the rank is unknown); false once the axis
+ //     has been validated.
+ //
+ // Returns InvalidArgument if `pack` has no "axis" attribute or if the axis is
+ // out of bounds for the Pack output rank.
+ Status CheckInputs(const NodeDef* node, const NodeDef* pack,
+                    PartialTensorShape* pack_output_shape, int* pack_axis,
+                    bool* return_early) {
+   *return_early = true;
+   TF_RETURN_IF_ERROR(CheckAttrExists(*pack, "axis"));
+
+   *pack_axis = pack->attr().at("axis").i();
+   const auto& slice_properties =
+       ctx().graph_properties->GetInputProperties(node->name());
+   if (slice_properties.empty()) {
+     // No inferred properties for this node: nothing to index into.
+     return Status::OK();
+   }
+   *pack_output_shape = slice_properties[0].shape();
+   if (pack_output_shape->unknown_rank()) {
+     return Status::OK();
+   }
+   // Pack inserts a new dimension, so for inputs of rank R the output has rank
+   // R + 1 and the valid axis range is [-(R + 1), R + 1). Negative values
+   // therefore wrap around the *output* rank, and axis == R (stack on a new
+   // last dimension) is legal.
+   const int pack_output_rank = pack_output_shape->dims();
+   if (*pack_axis < 0) {
+     *pack_axis += pack_output_rank;
+   }
+   if (*pack_axis < 0 || *pack_axis >= pack_output_rank) {
+     return errors::InvalidArgument(
+         "Pack node (", pack->name(),
+         ") axis attribute is out of bounds: ", pack->attr().at("axis").i());
+   }
+   *return_early = false;
+   return Status::OK();
+ }
+
+ // Determines whether StridedSlice `node` extracts exactly one element along
+ // the axis that `pack` stacked on.
+ //
+ // When *found is set to true, *slice_start_value holds the index of the Pack
+ // input selected by the slice. *found stays false (with an OK status) when
+ // the pattern does not apply:
+ //   - a begin/end/strides input is not an unfed constant,
+ //   - one of those constants cannot be parsed into a Tensor,
+ //   - ValidateStridedSliceOp does not report a simple slice,
+ //   - more than one axis is restricted, or the restricted axis differs from
+ //     `pack_axis`,
+ //   - the slice does not cover exactly one element.
+ //
+ // Returns InvalidArgument when a required attribute is missing or when the
+ // selected index does not address an input of `pack`. Errors from
+ // GetInputNode and ValidateStridedSliceOp are propagated.
+ Status GetSliceAxis(const NodeDef* node, const NodeDef* pack,
+                     const PartialTensorShape& pack_output_shape,
+                     int pack_axis, int* slice_start_value, bool* found) {
+   *found = false;
+   for (auto key : {"begin_mask", "end_mask", "ellipsis_mask", "new_axis_mask",
+                    "shrink_axis_mask"}) {
+     TF_RETURN_IF_ERROR(CheckAttrExists(*node, key));
+   }
+
+   const int begin_mask = node->attr().at("begin_mask").i();
+   const int end_mask = node->attr().at("end_mask").i();
+   const int ellipsis_mask = node->attr().at("ellipsis_mask").i();
+   const int new_axis_mask = node->attr().at("new_axis_mask").i();
+   const int shrink_axis_mask = node->attr().at("shrink_axis_mask").i();
+
+   // Check that the StridedSlice is one of these at pack_axis:
+   //   [..., i, ...]
+   //   [..., i:i+1, ...]
+   //   [..., :1, ...]
+   //   [..., -1:, ...]
+   //   [..., s_{pack_axis}-1:, ...]
+   NodeDef* slice_begin;
+   NodeDef* slice_end;
+   NodeDef* slice_strides;
+   TF_RETURN_IF_ERROR(GetInputNode(node->input(1), &slice_begin));
+   TF_RETURN_IF_ERROR(GetInputNode(node->input(2), &slice_end));
+   TF_RETURN_IF_ERROR(GetInputNode(node->input(3), &slice_strides));
+
+   for (const auto* n : {slice_begin, slice_end, slice_strides}) {
+     if (!IsReallyConstant(*n)) return Status::OK();
+   }
+
+   Tensor slice_begin_t;
+   Tensor slice_end_t;
+   Tensor slice_strides_t;
+
+   // All three constants are read through attr().at("value") below, so all
+   // three must be checked, including the strides.
+   TF_RETURN_IF_ERROR(CheckAttrExists(*slice_begin, "value"));
+   TF_RETURN_IF_ERROR(CheckAttrExists(*slice_end, "value"));
+   TF_RETURN_IF_ERROR(CheckAttrExists(*slice_strides, "value"));
+
+   if (!slice_begin_t.FromProto(slice_begin->attr().at("value").tensor())) {
+     return Status::OK();
+   }
+   if (!slice_end_t.FromProto(slice_end->attr().at("value").tensor())) {
+     return Status::OK();
+   }
+   if (!slice_strides_t.FromProto(
+           slice_strides->attr().at("value").tensor())) {
+     return Status::OK();
+   }
+   TensorShape processing_shape;
+   TensorShape final_shape;
+   bool is_identity;
+   bool is_simple_slice;
+   bool slice_dim0;
+   gtl::InlinedVector<int64, 4> slice_begin_vec;
+   gtl::InlinedVector<int64, 4> slice_end_vec;
+   gtl::InlinedVector<int64, 4> slice_strides_vec;
+   TF_RETURN_IF_ERROR(ValidateStridedSliceOp(
+       &slice_begin_t, &slice_end_t, slice_strides_t, pack_output_shape,
+       begin_mask, end_mask, ellipsis_mask, new_axis_mask, shrink_axis_mask,
+       &processing_shape, &final_shape, &is_identity, &is_simple_slice,
+       &slice_dim0, &slice_begin_vec, &slice_end_vec, &slice_strides_vec));
+
+   if (!is_simple_slice) return Status::OK();
+
+   // Find the single axis (if any) whose begin value is not 0.
+   int begin_index = -1;
+   int64 begin_value = 0;
+   const int num_begin = static_cast<int>(slice_begin_vec.size());
+   for (int i = 0; i < num_begin; ++i) {
+     const int64 v = slice_begin_vec[i];
+     if (v != 0) {
+       if (begin_index != -1) {
+         // At least two start values that are nonzero.
+         return Status::OK();
+       }
+       begin_index = i;
+       begin_value = v;
+     }
+   }
+
+   // Find the single axis (if any) whose end value stops short of the full
+   // dimension size.
+   int end_index = -1;
+   int64 end_value = 0;
+   const int num_end = static_cast<int>(slice_end_vec.size());
+   for (int i = 0; i < num_end; ++i) {
+     const int64 v = slice_end_vec[i];
+     if (v != pack_output_shape.dim_size(i)) {
+       if (end_index != -1) {
+         // At least two end values that differ from the dimension size.
+         return Status::OK();
+       }
+       end_index = i;
+       end_value = v;
+     }
+   }
+
+   if (begin_index == -1 && end_index == -1) return Status::OK();
+   if (begin_index != -1 && end_index != -1 && begin_index != end_index) {
+     // Somehow received different axes for begin/end slicing
+     return Status::OK();
+   }
+   const int slice_axis = (begin_index == -1) ? end_index : begin_index;
+   if (slice_axis != pack_axis) {
+     // Not slicing on the same axis as the Pack op.
+     return Status::OK();
+   }
+   *slice_start_value = (begin_index == -1) ? 0 : begin_value;
+   const int64 slice_end_value =
+       (end_index == -1) ? pack_output_shape.dim_size(slice_axis) : end_value;
+   if (slice_end_value != *slice_start_value + 1) {
+     // Not slicing a single value out.
+     return Status::OK();
+   }
+
+   if (*slice_start_value < 0 || *slice_start_value >= pack->input_size()) {
+     return errors::InvalidArgument(
+         "Node ", node->name(), " requested invalid slice index ",
+         *slice_start_value, " on axis ", slice_axis,
+         " from tensor of shape: ", pack_output_shape.DebugString());
+   }
+
+   *found = true;  // slice_start_value is valid.
+   return Status::OK();
+ }
+
+ // Emits the node that replaces StridedSlice `node`.
+ //
+ // The replacement reads input number `slice_start_value` of `pack` directly:
+ //   - as an Identity when that input's shape is compatible with the inferred
+ //     output shape of `node`;
+ //   - otherwise as ExpandDims(input, pack_axis), with `pack_axis` supplied by
+ //     a freshly created int32 Const node.
+ //
+ // Control dependencies of `node` and `pack` are forwarded to the new node,
+ // the new nodes are queued for further optimization, and the new node's name
+ // is written to *simplified_node_name. Errors from GetInputNode and
+ // GetTensorProperties are propagated.
+ Status RewriteGraph(const NodeDef* node, const NodeDef* pack,
+                     int slice_start_value, int pack_axis,
+                     string* simplified_node_name) {
+   // Wire the replacement to the Pack input string itself rather than to the
+   // producer's node name: an input such as "producer:1" must keep its output
+   // port, otherwise the rewrite would silently read output 0. Copied so the
+   // string stays valid while new nodes are added to the graph.
+   const string input_slice_tensor = pack->input(slice_start_value);
+
+   // Resolving the producer reports a dangling input before any node is added.
+   NodeDef* input_slice;
+   TF_RETURN_IF_ERROR(GetInputNode(input_slice_tensor, &input_slice));
+
+   OpInfo::TensorProperties input_slice_properties;
+   TF_RETURN_IF_ERROR(
+       GetTensorProperties(input_slice_tensor, &input_slice_properties));
+   PartialTensorShape input_slice_shape(input_slice_properties.shape());
+
+   OpInfo::TensorProperties output_properties;
+   TF_RETURN_IF_ERROR(GetTensorProperties(
+       strings::StrCat(node->name(), ":", 0), &output_properties));
+   PartialTensorShape output_shape(output_properties.shape());
+   NodeDef* output =
+       AddEmptyNode(OptimizedNodeName(ParseNodeScopeAndName(node->name())));
+   if (input_slice_shape.IsCompatibleWith(output_shape)) {
+     // The slice dropped the stacked axis: the Pack input is the result as-is.
+     output->set_op("Identity");
+     output->set_device(node->device());
+     SetDataTypeToAttr(output_properties.dtype(), "T", output);
+     output->add_input(input_slice_tensor);
+   } else {
+     // The slice kept the stacked axis with size 1: re-insert it explicitly.
+     NodeDef* axis = AddEmptyNode(
+         OptimizedNodeName(ParseNodeScopeAndName(node->name()), "Axis"));
+     axis->set_op("Const");
+     axis->set_device(node->device());
+     auto axis_attr = axis->mutable_attr();
+     SetDataTypeToAttr(DT_INT32, "dtype", axis);
+     auto* axis_t = (*axis_attr)["value"].mutable_tensor();
+     axis_t->set_dtype(DT_INT32);
+     axis_t->add_int_val(pack_axis);
+     AddToOptimizationQueue(axis);
+     output->set_op("ExpandDims");
+     output->set_device(node->device());
+     SetDataTypeToAttr(output_properties.dtype(), "T", output);
+     // State the axis dtype explicitly so the node matches the int32 Const
+     // above without relying on the op's default.
+     SetDataTypeToAttr(DT_INT32, "Tdim", output);
+     output->add_input(input_slice_tensor);
+     output->add_input(axis->name());
+   }
+
+   // Copy dependencies over.
+   ForwardControlDependencies(output, {node, pack});
+   AddToOptimizationQueue(output);
+   *simplified_node_name = output->name();
+
+   return Status::OK();
+ }
+};
+
} // namespace
class UniqueNodes {
@@ -3132,7 +3419,7 @@ Status ArithmeticOptimizer::SimplifyArithmeticOps(bool can_use_shapes) {
const GraphOptimizerContext ctx(&nodes_to_preserve_, optimized_graph_,
graph_properties_.get(), node_map_.get(),
- opt_level_);
+ &feed_nodes_, opt_level_);
const ArithmeticOptimizerContext ctx_ext(&nodes_to_simplify);
// Stop pipeline after first stage returning non-empty simplified tensor name.
@@ -3186,6 +3473,8 @@ Status ArithmeticOptimizer::SimplifyArithmeticOps(bool can_use_shapes) {
pipeline.AddStage<ConvertExpm1Stage>(ctx, ctx_ext);
if (options_.unary_ops_composition)
pipeline.AddStage<UnaryOpsComposition>(ctx, ctx_ext);
+ if (options_.remove_stack_strided_slice_same_axis)
+ pipeline.AddStage<RemoveStackStridedSliceSameAxis>(ctx, ctx_ext);
VLOG(1) << "Run " << pipeline.NumStages() << " arithmetic optimizer stages: "
<< str_util::Join(pipeline.StageNames(), ", ");
@@ -3249,6 +3538,10 @@ Status ArithmeticOptimizer::Optimize(Cluster* /*cluster*/,
optimized_graph_ = &optimized_item.graph;
node_map_.reset(new NodeMap(optimized_graph_));
+ for (const auto& feed : item.feed) {
+ feed_nodes_.insert(NodeName(feed.first));
+ }
+
// Disable restricted graph rewrites.
options_.unary_ops_composition &=
item.allowed_optimizations.non_differentiable_rewrites;