aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar Yao Zhang <yaozhang@google.com>2017-12-22 14:10:56 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-12-22 14:14:55 -0800
commitbb453996b7e1cfd17e201d8abb88a84a6f8b3aa1 (patch)
tree4735606dee03a3a0ba361fb099aeb240ca539d72
parent3e27b27cf4d96bb52b39bcc17304cea8785d010a (diff)
Support non-constant param input of AvgPoolGrad and Sum.
PiperOrigin-RevId: 179962212
-rw-r--r--tensorflow/core/grappler/optimizers/layout_optimizer.cc55
1 files changed, 10 insertions, 45 deletions
diff --git a/tensorflow/core/grappler/optimizers/layout_optimizer.cc b/tensorflow/core/grappler/optimizers/layout_optimizer.cc
index c9e5f842be..2786b8cf62 100644
--- a/tensorflow/core/grappler/optimizers/layout_optimizer.cc
+++ b/tensorflow/core/grappler/optimizers/layout_optimizer.cc
@@ -926,7 +926,7 @@ class AvgPoolGradProcessor : public NodeProcessor {
protected:
std::vector<int> GetInputPos() const override { return {1}; }
Status CustomizedProcessing() override {
- return UpdateAttrValueOfInput(0, true);
+ return UpdateOrTransformParamInput(0, "DataFormatVecPermute", DT_INT32);
}
};
@@ -1062,9 +1062,7 @@ class Conv2DBackpropInputProcessor : public Conv2DProcessor {
std::vector<int> GetInputPos() const override { return {2}; }
Status CustomizedProcessing() override {
- TF_RETURN_IF_ERROR(
- UpdateOrTransformParamInput(0, "DataFormatVecPermute", DT_INT32));
- return Status::OK();
+ return UpdateOrTransformParamInput(0, "DataFormatVecPermute", DT_INT32);
}
};
@@ -1371,9 +1369,7 @@ class FillProcessor : public AgnosticNodeProcessor {
Status CustomizedProcessing() override {
DataType dtype = node_->attr().at("index_type").type();
- TF_RETURN_IF_ERROR(
- UpdateOrTransformParamInput(0, "DataFormatVecPermute", dtype));
- return Status::OK();
+ return UpdateOrTransformParamInput(0, "DataFormatVecPermute", dtype);
}
};
@@ -1470,9 +1466,7 @@ class PadProcessor : public AgnosticNodeProcessor {
protected:
Status CustomizedProcessing() override {
DataType dtype = node_->attr().at("Tpaddings").type();
- TF_RETURN_IF_ERROR(
- UpdateOrTransformParamInput(1, "DataFormatVecPermute", dtype));
- return Status::OK();
+ return UpdateOrTransformParamInput(1, "DataFormatVecPermute", dtype);
}
};
@@ -1484,9 +1478,7 @@ class ReverseProcessor : public AgnosticNodeProcessor {
protected:
Status CustomizedProcessing() override {
DataType dtype = node_->attr().at("Tidx").type();
- TF_RETURN_IF_ERROR(
- UpdateOrTransformParamInput(1, "DataFormatDimMap", dtype));
- return Status::OK();
+ return UpdateOrTransformParamInput(1, "DataFormatDimMap", dtype);
}
};
@@ -1511,9 +1503,8 @@ class SplitProcessor : public AgnosticNodeProcessor {
}
Status CustomizedProcessing() override {
- TF_RETURN_IF_ERROR(UpdateOrTransformParamInput(
- axis_node_pos_, "DataFormatDimMap", DT_INT32));
- return Status::OK();
+ return UpdateOrTransformParamInput(axis_node_pos_, "DataFormatDimMap",
+ DT_INT32);
}
int axis_node_pos_;
@@ -1629,40 +1620,14 @@ class SumProcessor : public AgnosticNodeProcessor {
int port;
ParseNodeName(node_->input(0), &port);
return !MustPreserve() && HasOutputs() && IsNodeAfterNCHWToNHWC() &&
- IsPortDimsFour(*input0, port) && IsAlongDimNHW() && IsOnGPU();
+ IsPortDimsFour(*input0, port) && IsOnGPU();
}
Status AddLayoutTransposeToOutputs() override { return Status::OK(); }
Status CustomizedProcessing() override {
- return UpdateAttrValueOfInput(1, false);
- }
-
- private:
- bool IsAlongDimNHW() const {
- NodeDef* reduction_indices = node_map_->GetNode(node_->input(1));
- if (!IsConstant(*reduction_indices)) {
- return false;
- }
- Tensor tensor;
- if (reduction_indices->attr().find({"value"}) ==
- reduction_indices->attr().end()) {
- return false;
- }
- auto success =
- tensor.FromProto(reduction_indices->attr().at({"value"}).tensor());
- if (!success) {
- LOG(ERROR) << "Failed to parse TensorProto.";
- return false;
- }
- if (tensor.flat<int>().size() != 3) {
- return false;
- }
- if (tensor.flat<int>()(0) == 0 && tensor.flat<int>()(1) == 1 &&
- tensor.flat<int>()(2) == 2) {
- return true;
- }
- return false;
+ DataType dtype = node_->attr().at("Tidx").type();
+ return UpdateOrTransformParamInput(1, "DataFormatDimMap", dtype);
}
};