diff options
author | 2017-12-22 14:10:56 -0800 | |
---|---|---|
committer | 2017-12-22 14:14:55 -0800 | |
commit | bb453996b7e1cfd17e201d8abb88a84a6f8b3aa1 (patch) | |
tree | 4735606dee03a3a0ba361fb099aeb240ca539d72 | |
parent | 3e27b27cf4d96bb52b39bcc17304cea8785d010a (diff) |
Support non-constant param input of AvgPoolGrad and Sum.
PiperOrigin-RevId: 179962212
-rw-r--r-- | tensorflow/core/grappler/optimizers/layout_optimizer.cc | 55 |
1 files changed, 10 insertions, 45 deletions
diff --git a/tensorflow/core/grappler/optimizers/layout_optimizer.cc b/tensorflow/core/grappler/optimizers/layout_optimizer.cc index c9e5f842be..2786b8cf62 100644 --- a/tensorflow/core/grappler/optimizers/layout_optimizer.cc +++ b/tensorflow/core/grappler/optimizers/layout_optimizer.cc @@ -926,7 +926,7 @@ class AvgPoolGradProcessor : public NodeProcessor { protected: std::vector<int> GetInputPos() const override { return {1}; } Status CustomizedProcessing() override { - return UpdateAttrValueOfInput(0, true); + return UpdateOrTransformParamInput(0, "DataFormatVecPermute", DT_INT32); } }; @@ -1062,9 +1062,7 @@ class Conv2DBackpropInputProcessor : public Conv2DProcessor { std::vector<int> GetInputPos() const override { return {2}; } Status CustomizedProcessing() override { - TF_RETURN_IF_ERROR( - UpdateOrTransformParamInput(0, "DataFormatVecPermute", DT_INT32)); - return Status::OK(); + return UpdateOrTransformParamInput(0, "DataFormatVecPermute", DT_INT32); } }; @@ -1371,9 +1369,7 @@ class FillProcessor : public AgnosticNodeProcessor { Status CustomizedProcessing() override { DataType dtype = node_->attr().at("index_type").type(); - TF_RETURN_IF_ERROR( - UpdateOrTransformParamInput(0, "DataFormatVecPermute", dtype)); - return Status::OK(); + return UpdateOrTransformParamInput(0, "DataFormatVecPermute", dtype); } }; @@ -1470,9 +1466,7 @@ class PadProcessor : public AgnosticNodeProcessor { protected: Status CustomizedProcessing() override { DataType dtype = node_->attr().at("Tpaddings").type(); - TF_RETURN_IF_ERROR( - UpdateOrTransformParamInput(1, "DataFormatVecPermute", dtype)); - return Status::OK(); + return UpdateOrTransformParamInput(1, "DataFormatVecPermute", dtype); } }; @@ -1484,9 +1478,7 @@ class ReverseProcessor : public AgnosticNodeProcessor { protected: Status CustomizedProcessing() override { DataType dtype = node_->attr().at("Tidx").type(); - TF_RETURN_IF_ERROR( - UpdateOrTransformParamInput(1, "DataFormatDimMap", dtype)); - return Status::OK(); + return UpdateOrTransformParamInput(1, "DataFormatDimMap", dtype); } }; @@ -1511,9 +1503,8 @@ class SplitProcessor : public AgnosticNodeProcessor { } Status CustomizedProcessing() override { - TF_RETURN_IF_ERROR(UpdateOrTransformParamInput( - axis_node_pos_, "DataFormatDimMap", DT_INT32)); - return Status::OK(); + return UpdateOrTransformParamInput(axis_node_pos_, "DataFormatDimMap", + DT_INT32); } int axis_node_pos_; @@ -1629,40 +1620,14 @@ class SumProcessor : public AgnosticNodeProcessor { int port; ParseNodeName(node_->input(0), &port); return !MustPreserve() && HasOutputs() && IsNodeAfterNCHWToNHWC() && - IsPortDimsFour(*input0, port) && IsAlongDimNHW() && IsOnGPU(); + IsPortDimsFour(*input0, port) && IsOnGPU(); } Status AddLayoutTransposeToOutputs() override { return Status::OK(); } Status CustomizedProcessing() override { - return UpdateAttrValueOfInput(1, false); - } - - private: - bool IsAlongDimNHW() const { - NodeDef* reduction_indices = node_map_->GetNode(node_->input(1)); - if (!IsConstant(*reduction_indices)) { - return false; - } - Tensor tensor; - if (reduction_indices->attr().find({"value"}) == - reduction_indices->attr().end()) { - return false; - } - auto success = - tensor.FromProto(reduction_indices->attr().at({"value"}).tensor()); - if (!success) { - LOG(ERROR) << "Failed to parse TensorProto."; - return false; - } - if (tensor.flat<int>().size() != 3) { - return false; - } - if (tensor.flat<int>()(0) == 0 && tensor.flat<int>()(1) == 1 && - tensor.flat<int>()(2) == 2) { - return true; - } - return false; + DataType dtype = node_->attr().at("Tidx").type(); + return UpdateOrTransformParamInput(1, "DataFormatDimMap", dtype); } }; |