 tensorflow/core/grappler/op_types.cc                         |   2
 tensorflow/core/grappler/op_types.h                          |   1
 tensorflow/core/grappler/optimizers/constant_folding.cc      |   2
 tensorflow/core/grappler/optimizers/layout_optimizer.cc      | 430
 tensorflow/core/grappler/optimizers/layout_optimizer_test.cc |  11
 tensorflow/core/kernels/constant_op.cc                       |  38
 tensorflow/core/ops/array_ops.cc                             |  11
 tensorflow/core/ops/array_ops_test.cc                        |   1
 tensorflow/python/eager/ops_test.py                          |   9
 tensorflow/python/framework/constant_op.py                   |   3
 tensorflow/python/grappler/layout_optimizer_test.py          |  72
 tensorflow/python/ops/parsing_ops.py                         |   4
 12 files changed, 344 insertions(+), 240 deletions(-)
diff --git a/tensorflow/core/grappler/op_types.cc b/tensorflow/core/grappler/op_types.cc
index 6e7558c00e..e517159721 100644
--- a/tensorflow/core/grappler/op_types.cc
+++ b/tensorflow/core/grappler/op_types.cc
@@ -225,6 +225,8 @@ bool IsRestore(const NodeDef& node) {
           node.op() == "RestoreSlice");
 }
 
+bool IsReverseV2(const NodeDef& node) { return node.op() == "ReverseV2"; }
+
 bool IsRsqrtGrad(const NodeDef& node) { return node.op() == "RsqrtGrad"; }
 
 bool IsSelect(const NodeDef& node) { return node.op() == "Select"; }
diff --git a/tensorflow/core/grappler/op_types.h b/tensorflow/core/grappler/op_types.h
index d9dc147a39..28d6ce52e1 100644
--- a/tensorflow/core/grappler/op_types.h
+++ b/tensorflow/core/grappler/op_types.h
@@ -89,6 +89,7 @@ bool IsRecv(const NodeDef& node);
 bool IsReduction(const NodeDef& node);
 bool IsReshape(const NodeDef& node);
 bool IsRestore(const NodeDef& node);
+bool IsReverseV2(const NodeDef& node);
 bool IsRsqrtGrad(const NodeDef& node);
 bool IsSelect(const NodeDef& node);
 bool IsSeluGrad(const NodeDef& node);
diff --git a/tensorflow/core/grappler/optimizers/constant_folding.cc b/tensorflow/core/grappler/optimizers/constant_folding.cc
index 59df49c245..62de52316e 100644
--- a/tensorflow/core/grappler/optimizers/constant_folding.cc
+++ b/tensorflow/core/grappler/optimizers/constant_folding.cc
@@ -1486,7 +1486,7 @@ Status ConstantFolding::SimplifyGraph(GraphDef* output,
     // TODO(rmlarsen): Handle non-associative/non-commutative operators like
     // subtraction and division, as well as mixed subtraction/addition,
     // division/multiplication.
-    if ((is_add || is_mul) && NumNonControlInputs(*node) == 2) {
+    if ((IsAdd(*node) || is_mul) && NumNonControlInputs(*node) == 2) {
      NodeDef* left_child = node_map_->GetNode(node->input(0));
      NodeDef* right_child = node_map_->GetNode(node->input(1));
      // One child must be constant, and the other the same op as the parent.
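
The constant_folding.cc change narrows the guard on Grappler's constant-pushdown rewrite: it now tests IsAdd(*node) directly instead of the precomputed is_add flag. The rewrite itself reorders an associative, commutative chain so that the two constants become siblings, e.g. Add(Add(x, c1), c2) -> Add(x, Add(c1, c2)), which a later folding pass collapses to a single constant. A minimal sketch of that idea in plain Python, using a toy expression tree rather than Grappler's NodeDef graph (all names here are illustrative):

    # Toy sketch of the constant-pushdown rewrite guarded above; uses a
    # tiny hand-rolled expression tree, not Grappler's NodeDef graph.

    class Node:
        def __init__(self, op, inputs=(), value=None):
            self.op = op              # 'Add', 'Mul', or 'Const'
            self.inputs = list(inputs)
            self.value = value        # payload for 'Const' nodes

    def _split(node):
        # Returns (const_child, other_child), or None if neither child
        # is a constant.
        a, b = node.inputs
        if a.op == 'Const':
            return a, b
        if b.op == 'Const':
            return b, a
        return None

    def push_down_constants(node):
        # Guard mirrors the diff: the parent must be Add or Mul with
        # exactly two inputs, one child constant, and the other child
        # the same op as the parent.
        if node.op not in ('Add', 'Mul') or len(node.inputs) != 2:
            return node
        split = _split(node)
        if split is None:
            return node
        c2, inner = split
        if inner.op != node.op:
            return node
        inner_split = _split(inner)
        if inner_split is None:
            return node
        c1, x = inner_split
        # op(op(x, c1), c2) -> op(x, op(c1, c2)); the constant subtree
        # can now be folded to a single Const by a later pass.
        return Node(node.op, [x, Node(node.op, [c1, c2])])

    x = Node('Load')  # stand-in for a non-constant input
    c1, c2 = Node('Const', value=2.0), Node('Const', value=3.0)
    tree = push_down_constants(Node('Add', [Node('Add', [x, c1]), c2]))
    # tree is now Add(x, Add(c1, c2)); folding Add(c1, c2) yields Const(5.0).
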
diff --git a/tensorflow/core/grappler/optimizers/layout_optimizer.cc b/tensorflow/core/grappler/optimizers/layout_optimizer.cc
index 6bedcc4b36..3b52728b3b 100644
--- a/tensorflow/core/grappler/optimizers/layout_optimizer.cc
+++ b/tensorflow/core/grappler/optimizers/layout_optimizer.cc
@@ -41,6 +41,8 @@ const char kPermNHWCToNCHW[] = "LayoutOptimizerPermConstNHWCToNCHW";
 const char kPermNCHWToNHWC[] = "LayoutOptimizerPermConstNCHWToNHWC";
 const char kTransposeNHWCToNCHW[] = "LayoutOptimizerTransposeNHWCToNCHW";
 const char kTransposeNCHWToNHWC[] = "LayoutOptimizerTransposeNCHWToNHWC";
+const char kDimMapNHWCToNCHW[] = "LayoutOptimizerDimMapNHWCToNCHW";
+const char kDimMapNCHWToNHWC[] = "LayoutOptimizerDimMapNCHWToNHWC";
 const char kVecPermuteNHWCToNCHW[] = "LayoutOptimizerVecPermuteNHWCToNCHW";
 const char kVecPermuteNCHWToNHWC[] = "LayoutOptimizerVecPermuteNCHWToNHWC";
 const char kReshapeNHWCToNCHW[] = "LayoutOptimizerReshapeNHWCToNCHW";
@@ -179,6 +181,7 @@ std::set<string> GetOpsFormatAgnostic() {
                                           "Switch",
                                           "TruncateDiv",
                                           "TruncateMod",
+                                          "ReverseV2",
                                           "Round",
                                           "Rsqrt",
                                           "RsqrtGrad",
@@ -453,12 +456,7 @@ class NodeProcessor : public GraphProcessor {
     return nodes_to_preserve_.find(node_->name()) != nodes_to_preserve_.end();
   }
 
-  virtual bool ShouldProcess() const {
-    return !MustPreserve() && IsNHWC() && IsPortZeroDimsFour(*node_) &&
-           HasOutputs() && IsOnGPU();
-  }
-
-  virtual bool IsOnGPU() const {
+  bool IsOnGPU() const {
     string device_name;
     if (node_->device().empty()) {
       device_name = virtual_placer_.get_canonical_device_name(*node_);
@@ -475,14 +473,9 @@ class NodeProcessor : public GraphProcessor {
     return false;
   }
 
-  void UpdateAttrDataFormat() {
-    if (node_->attr().find("data_format") != node_->attr().end()) {
-      if (node_->attr().at("data_format").s().compare("NHWC") == 0) {
-        string* data_format =
-            node_->mutable_attr()->at("data_format").mutable_s();
-        *data_format = "NCHW";
-      }
-    }
+  virtual bool ShouldProcess() const {
+    return !MustPreserve() && IsNHWC() && IsPortZeroDimsFour(*node_) &&
+           HasOutputs() && IsOnGPU();
   }
 
   virtual void UpdateAttrShape() {
@@ -504,74 +497,7 @@ class NodeProcessor : public GraphProcessor {
     }
   }
 
-  void UpdateAttrKSize() {
-    if (node_->attr().find("ksize") != node_->attr().end()) {
-      auto list = node_->mutable_attr()->at("ksize").mutable_list();
-      UpdateTuple(list);
-    }
-  }
-
-  void UpdateAttrStrides() {
-    if (node_->attr().find("strides") != node_->attr().end()) {
-      auto list = node_->mutable_attr()->at("strides").mutable_list();
-      UpdateTuple(list);
-    }
-  }
-
-  Status UpdateAttrValue(NodeDef* node) {
-    TF_RETURN_IF_ERROR(HasAttribute(*node, "value"));
-    Tensor tensor;
-    auto success =
-        tensor.FromProto(node->mutable_attr()->at({"value"}).tensor());
-    if (!success) {
-      LOG(ERROR) << "Failed to parse TensorProto.";
-    }
-    if (tensor.dims() == 0) {
-      int value = tensor.scalar<int>()();
-      value = (value >= 0) ? value : value + 4;
-      if (value == 1 || value == 2) {
-        value = value + 1;
-      } else if (value == 3) {
-        value = 1;
-      }
-      tensor.scalar<int>()() = value;
-    } else if (tensor.dims() == 1) {
-      if (tensor.flat<int>().size() == 4) {
-        int c = tensor.flat<int>()(3);
-        tensor.flat<int>()(3) = tensor.flat<int>()(2);
-        tensor.flat<int>()(2) = tensor.flat<int>()(1);
-        tensor.flat<int>()(1) = c;
-      } else if (tensor.flat<int>().size() == 3) {
-        tensor.flat<int>()(0) = 0;
-        tensor.flat<int>()(1) = 2;
-        tensor.flat<int>()(2) = 3;
-      } else {
-        return Status(error::INVALID_ARGUMENT,
-                      strings::StrCat("Unsupported tensor size: ",
-                                      tensor.flat<int>().size()));
-      }
-    } else if (tensor.dims() == 2) {
-      for (int i = 0; i < 2; i++) {
-        int c = tensor.matrix<int>()(3, i);
-        tensor.matrix<int>()(3, i) = tensor.matrix<int>()(2, i);
-        tensor.matrix<int>()(2, i) = tensor.matrix<int>()(1, i);
-        tensor.matrix<int>()(1, i) = c;
-      }
-    } else {
-      return Status(
-          error::INVALID_ARGUMENT,
-          strings::StrCat("Unsupported dimension size: ", tensor.dims()));
-    }
-    if (tensor.dims() == 0) {
-      tensor.AsProtoField(node->mutable_attr()->at({"value"}).mutable_tensor());
-    } else {
-      tensor.AsProtoTensorContent(
-          node->mutable_attr()->at({"value"}).mutable_tensor());
-    }
-    return Status::OK();
-  }
-
-  Status UpdateAttrValueOfInput(int input_index) {
+  Status UpdateAttrValueOfInput(int input_index, bool permute) {
     auto input_node = node_map_->GetNode(node_->input(input_index));
     // We created a copy of the node, so that we don't modify the original node,
     // which might be used elsewhere. Note that this copy also copies the
@@ -585,7 +511,7 @@ class NodeProcessor : public GraphProcessor {
     *node_->mutable_input(input_index) = node_name;
     node_map_->AddNode(node_name, added_node);
     node_map_->AddOutput(node_name, node_->name());
-    return UpdateAttrValue(added_node);
+    return UpdateAttrValue(added_node, permute);
   }
 
   virtual std::vector<int> GetInputPos() const {
@@ -601,42 +527,6 @@ class NodeProcessor : public GraphProcessor {
     return output_pos;
   }
 
-  NodeDef* AddNodeTranspose(const string& node_name, const string& input_name,
-                            const string& const_name, DataType data_type,
-                            const TensorShapeProto& input_shape,
-                            bool NHWCToNCHW) {
-    NodeDef* node = graph_->add_node();
-    node_map_->AddNode(node_name, node);
-    node->set_name(node_name);
-    *node->add_input() = input_name;
-    *node->add_input() = const_name;
-    node->set_op("Transpose");
-    node->set_device(node_->device());
-    AttrValue attr_data_type;
-    attr_data_type.set_type(data_type);
-    node->mutable_attr()->insert({"T", attr_data_type});
-    AttrValue attr_data_type_perm;
-    attr_data_type_perm.set_type(DT_INT32);
-    node->mutable_attr()->insert({"Tperm", attr_data_type_perm});
-    if (!input_shape.unknown_rank()) {
-      AttrValue attr_output_shape;
-      auto output_shape = attr_output_shape.mutable_list()->add_shape();
-      if (NHWCToNCHW) {
-        output_shape->add_dim()->set_size(input_shape.dim(0).size());
-        output_shape->add_dim()->set_size(input_shape.dim(3).size());
-        output_shape->add_dim()->set_size(input_shape.dim(1).size());
-        output_shape->add_dim()->set_size(input_shape.dim(2).size());
-      } else {
-        output_shape->add_dim()->set_size(input_shape.dim(0).size());
-        output_shape->add_dim()->set_size(input_shape.dim(2).size());
-        output_shape->add_dim()->set_size(input_shape.dim(3).size());
-        output_shape->add_dim()->set_size(input_shape.dim(1).size());
-      }
-      node->mutable_attr()->insert({"_output_shapes", attr_output_shape});
-    }
-    return node;
-  }
-
   virtual Status AddLayoutTransposeToInputs() {
     std::vector<int> input_pos = GetInputPos();
     for (const auto& pos : input_pos) {
@@ -732,6 +622,137 @@ class NodeProcessor : public GraphProcessor {
 
   virtual Status CustomizedProcessing() { return Status::OK(); }
 
+  Status UpdateOrTransformParamInput(int param_index, const string& op,
+                                     DataType dtype) {
+    auto param_node = node_map_->GetNode(node_->input(param_index));
+    bool permute = (op == "DataFormatVecPermute") ? true : false;
+    if (IsConstant(*param_node)) {
+      TF_RETURN_IF_ERROR(UpdateAttrValueOfInput(param_index, permute));
+    } else {
+      AddDataFormatTranformToParamInput(op, param_index, dtype);
+    }
+    return Status::OK();
+  }
+
+  NodeDef* node_;
+  bool is_in_frame_;
+
+ private:
+  void UpdateAttrKSize() {
+    if (node_->attr().find("ksize") != node_->attr().end()) {
+      auto list = node_->mutable_attr()->at("ksize").mutable_list();
+      UpdateTuple(list);
+    }
+  }
+
+  void UpdateAttrStrides() {
+    if (node_->attr().find("strides") != node_->attr().end()) {
+      auto list = node_->mutable_attr()->at("strides").mutable_list();
+      UpdateTuple(list);
+    }
+  }
+
+  void UpdateAttrDataFormat() {
+    if (node_->attr().find("data_format") != node_->attr().end()) {
+      if (node_->attr().at("data_format").s().compare("NHWC") == 0) {
+        string* data_format =
+            node_->mutable_attr()->at("data_format").mutable_s();
+        *data_format = "NCHW";
+      }
+    }
+  }
+
+  Status UpdateAttrValue(NodeDef* node, bool permute) {
+    TF_RETURN_IF_ERROR(HasAttribute(*node, "value"));
+    Tensor tensor;
+    auto success =
+        tensor.FromProto(node->mutable_attr()->at({"value"}).tensor());
+    if (!success) {
+      LOG(ERROR) << "Failed to parse TensorProto.";
+    }
+
+    if (permute) {
+      if (tensor.dims() == 1) {
+        if (tensor.flat<int>().size() == 4) {
+          int c = tensor.flat<int>()(3);
+          tensor.flat<int>()(3) = tensor.flat<int>()(2);
+          tensor.flat<int>()(2) = tensor.flat<int>()(1);
+          tensor.flat<int>()(1) = c;
+        } else {
+          return Status(error::INVALID_ARGUMENT,
+                        strings::StrCat("Unsupported tensor size: ",
+                                        tensor.flat<int>().size()));
+        }
+      } else if (tensor.dims() == 2) {
+        for (int i = 0; i < 2; i++) {
+          int c = tensor.matrix<int>()(3, i);
+          tensor.matrix<int>()(3, i) = tensor.matrix<int>()(2, i);
+          tensor.matrix<int>()(2, i) = tensor.matrix<int>()(1, i);
+          tensor.matrix<int>()(1, i) = c;
+        }
+      } else {
+        return Status(
+            error::INVALID_ARGUMENT,
+            strings::StrCat("Unsupported dimension size: ", tensor.dims()));
+      }
+    } else {
+      for (int i = 0; i < tensor.flat<int>().size(); i++) {
+        int value = tensor.flat<int>()(i);
+        value = (value >= 0) ? value : value + 4;
+        if (value == 1 || value == 2) {
+          value = value + 1;
+        } else if (value == 3) {
+          value = 1;
+        }
+        tensor.flat<int>()(i) = value;
+      }
+    }
+
+    if (tensor.dims() == 0) {
+      tensor.AsProtoField(node->mutable_attr()->at({"value"}).mutable_tensor());
+    } else {
+      tensor.AsProtoTensorContent(
+          node->mutable_attr()->at({"value"}).mutable_tensor());
+    }
+    return Status::OK();
+  }
+
+  NodeDef* AddNodeTranspose(const string& node_name, const string& input_name,
+                            const string& const_name, DataType data_type,
+                            const TensorShapeProto& input_shape,
+                            bool NHWCToNCHW) {
+    NodeDef* node = graph_->add_node();
+    node_map_->AddNode(node_name, node);
+    node->set_name(node_name);
+    *node->add_input() = input_name;
+    *node->add_input() = const_name;
+    node->set_op("Transpose");
+    node->set_device(node_->device());
+    AttrValue attr_data_type;
+    attr_data_type.set_type(data_type);
+    node->mutable_attr()->insert({"T", attr_data_type});
+    AttrValue attr_data_type_perm;
+    attr_data_type_perm.set_type(DT_INT32);
+    node->mutable_attr()->insert({"Tperm", attr_data_type_perm});
+    if (!input_shape.unknown_rank()) {
+      AttrValue attr_output_shape;
+      auto output_shape = attr_output_shape.mutable_list()->add_shape();
+      if (NHWCToNCHW) {
+        output_shape->add_dim()->set_size(input_shape.dim(0).size());
+        output_shape->add_dim()->set_size(input_shape.dim(3).size());
+        output_shape->add_dim()->set_size(input_shape.dim(1).size());
+        output_shape->add_dim()->set_size(input_shape.dim(2).size());
+      } else {
+        output_shape->add_dim()->set_size(input_shape.dim(0).size());
+        output_shape->add_dim()->set_size(input_shape.dim(2).size());
+        output_shape->add_dim()->set_size(input_shape.dim(3).size());
+        output_shape->add_dim()->set_size(input_shape.dim(1).size());
+      }
+      node->mutable_attr()->insert({"_output_shapes", attr_output_shape});
+    }
+    return node;
+  }
+
   NodeDef* AddNodePermNHWCToNCHW(const string& suffix,
                                  const string& depended_node,
                                  const string& device) {
@@ -754,44 +775,6 @@ class NodeProcessor : public GraphProcessor {
     return const_node;
   }
 
-  NodeDef* AddNodeDataFormatOp(const string& name, const string& input_name,
-                               const string& op, DataType dtype,
-                               bool nhwc_to_nchw) {
-    NodeDef* added_node = graph_->add_node();
-    added_node->set_name(name);
-    added_node->set_op(op);
-    node_map_->AddNode(added_node->name(), added_node);
-    added_node->set_device(node_->device());
-    AttrValue attr_data_type;
-    attr_data_type.set_type(dtype);
-    added_node->mutable_attr()->insert({"T", attr_data_type});
-    string src_format = (nhwc_to_nchw) ? "NHWC" : "NCHW";
-    string dst_format = (nhwc_to_nchw) ? "NCHW" : "NHWC";
-    AttrValue attr_format;
-    attr_format.set_s(src_format);
-    added_node->mutable_attr()->insert({"src_format", attr_format});
-    attr_format.set_s(dst_format);
-    added_node->mutable_attr()->insert({"dst_format", attr_format});
-    *added_node->add_input() = input_name;
-    return added_node;
-  }
-
-  void AddDataFormatTranformToInput(const string& op, int input_pos,
-                                    DataType dtype) {
-    string name = strings::StrCat(kVecPermuteNHWCToNCHW, "_", node_->name(),
-                                  "_", input_pos);
-    auto added_node =
-        AddNodeDataFormatOp(name, node_->input(input_pos), op, dtype, true);
-    *node_->mutable_input(input_pos) = added_node->name();
-    node_map_->UpdateOutput(added_node->input(0), node_->name(),
-                            added_node->name());
-    node_map_->AddOutput(added_node->name(), node_->name());
-  }
-
-  NodeDef* node_;
-  bool is_in_frame_;
-
- private:
   string GetOrAddNodePermNHWCToNCHW(int pos) {
     string const_name;
     if (is_in_frame_) {
@@ -833,6 +816,41 @@ class NodeProcessor : public GraphProcessor {
     list->set_i(2, h);
     list->set_i(3, w);
   }
+
+  NodeDef* AddNodeDataFormatOp(const string& name, const string& input_name,
+                               const string& op, DataType dtype,
+                               bool nhwc_to_nchw) {
+    NodeDef* added_node = graph_->add_node();
+    added_node->set_name(name);
+    added_node->set_op(op);
+    node_map_->AddNode(added_node->name(), added_node);
+    added_node->set_device(node_->device());
+    AttrValue attr_data_type;
+    attr_data_type.set_type(dtype);
+    added_node->mutable_attr()->insert({"T", attr_data_type});
+    string src_format = (nhwc_to_nchw) ? "NHWC" : "NCHW";
+    string dst_format = (nhwc_to_nchw) ? "NCHW" : "NHWC";
+    AttrValue attr_format;
+    attr_format.set_s(src_format);
+    added_node->mutable_attr()->insert({"src_format", attr_format});
+    attr_format.set_s(dst_format);
+    added_node->mutable_attr()->insert({"dst_format", attr_format});
+    *added_node->add_input() = input_name;
+    return added_node;
+  }
+
+  void AddDataFormatTranformToParamInput(const string& op, int input_pos,
+                                         DataType dtype) {
+    string prefix = (op == "DataFormatVecPermute") ? kVecPermuteNHWCToNCHW
+                                                   : kDimMapNHWCToNCHW;
+    string name = strings::StrCat(prefix, "_", node_->name(), "_", input_pos);
+    auto added_node =
+        AddNodeDataFormatOp(name, node_->input(input_pos), op, dtype, true);
+    *node_->mutable_input(input_pos) = added_node->name();
+    node_map_->UpdateOutput(added_node->input(0), node_->name(),
+                            added_node->name());
+    node_map_->AddOutput(added_node->name(), node_->name());
+  }
 };
 
 class AvgPoolGradProcessor : public NodeProcessor {
@@ -845,7 +863,9 @@ class AvgPoolGradProcessor : public NodeProcessor {
     std::vector<int> input_pos = {1};
     return input_pos;
   }
-  Status CustomizedProcessing() override { return UpdateAttrValueOfInput(0); }
+  Status CustomizedProcessing() override {
+    return UpdateAttrValueOfInput(0, true);
+  }
 };
 
 class BiasAddGradProcessor : public NodeProcessor {
@@ -986,12 +1006,8 @@ class Conv2DBackpropInputProcessor : public Conv2DProcessor {
   }
 
   Status CustomizedProcessing() override {
-    auto input_size_node = node_map_->GetNode(node_->input(0));
-    if (IsConstant(*input_size_node)) {
-      TF_RETURN_IF_ERROR(UpdateAttrValueOfInput(0));
-    } else {
-      AddDataFormatTranformToInput("DataFormatVecPermute", 0, DT_INT32);
-    }
+    TF_RETURN_IF_ERROR(
+        UpdateOrTransformParamInput(0, "DataFormatVecPermute", DT_INT32));
     return Status::OK();
   }
 };
@@ -1042,12 +1058,8 @@ class MaxPoolGradV2Processor : public MaxPoolGradProcessor {
  protected:
   Status CustomizedProcessing() override {
     for (int i = 3; i < node_->input_size(); i++) {
-      auto param_node = node_map_->GetNode(node_->input(i));
-      if (IsConstant(*param_node)) {
-        TF_RETURN_IF_ERROR(UpdateAttrValueOfInput(i));
-      } else {
-        AddDataFormatTranformToInput("DataFormatVecPermute", i, DT_INT32);
-      }
+      TF_RETURN_IF_ERROR(
+          UpdateOrTransformParamInput(i, "DataFormatVecPermute", DT_INT32));
     }
     return Status::OK();
   }
@@ -1072,12 +1084,8 @@ class MaxPoolV2Processor : public NodeProcessor {
 
   Status CustomizedProcessing() override {
     for (int i = 1; i < node_->input_size(); i++) {
-      auto param_node = node_map_->GetNode(node_->input(i));
-      if (IsConstant(*param_node)) {
-        TF_RETURN_IF_ERROR(UpdateAttrValueOfInput(i));
-      } else {
-        AddDataFormatTranformToInput("DataFormatVecPermute", i, DT_INT32);
-      }
+      TF_RETURN_IF_ERROR(
+          UpdateOrTransformParamInput(i, "DataFormatVecPermute", DT_INT32));
     }
     return Status::OK();
   }
@@ -1325,17 +1333,12 @@ class ConcatProcessor : public AgnosticNodeProcessor {
   }
 
   Status CustomizedProcessing() override {
-    auto dim_node = node_map_->GetNode(node_->input(axis_node_pos_));
-    if (IsConstant(*dim_node)) {
-      TF_RETURN_IF_ERROR(UpdateAttrValueOfInput(axis_node_pos_));
-    } else {
-      DataType dtype = (IsSplit(*node_) || IsSplitV(*node_))
-                           ? DT_INT32
-                           : node_->attr().at("Tidx").type();
-      AddDataFormatTranformToInput("DataFormatDimMap", axis_node_pos_, dtype);
-    }
+    DataType dtype = node_->attr().at("Tidx").type();
+    TF_RETURN_IF_ERROR(
+        UpdateOrTransformParamInput(axis_node_pos_, "DataFormatDimMap", dtype));
     return Status::OK();
   }
+
   int axis_node_pos_;
 };
 
@@ -1414,29 +1417,36 @@ class PadProcessor : public AgnosticNodeProcessor {
 
  protected:
   Status CustomizedProcessing() override {
-    auto index_node = node_map_->GetNode(node_->input(1));
-    if (IsConstant(*index_node)) {
-      TF_RETURN_IF_ERROR(UpdateAttrValueOfInput(1));
-    } else {
-      DataType dtype = node_->attr().at("Tpaddings").type();
-      AddDataFormatTranformToInput("DataFormatVecPermute", 1, dtype);
-    }
+    DataType dtype = node_->attr().at("Tpaddings").type();
+    TF_RETURN_IF_ERROR(
+        UpdateOrTransformParamInput(1, "DataFormatVecPermute", dtype));
     return Status::OK();
   }
 };
 
-class SplitProcessor : public ConcatProcessor {
+class ReverseProcessor : public AgnosticNodeProcessor {
+ public:
+  explicit ReverseProcessor(const OptimizeContext& opt_cxt)
+      : AgnosticNodeProcessor(opt_cxt) {}
+
+ protected:
+  Status CustomizedProcessing() override {
+    DataType dtype = node_->attr().at("Tidx").type();
+    TF_RETURN_IF_ERROR(
+        UpdateOrTransformParamInput(1, "DataFormatDimMap", dtype));
+    return Status::OK();
+  }
+};
+
+class SplitProcessor : public AgnosticNodeProcessor {
  public:
   explicit SplitProcessor(const OptimizeContext& opt_cxt)
-      : ConcatProcessor(opt_cxt) {
+      : AgnosticNodeProcessor(opt_cxt) {
     axis_node_pos_ = 0;
   }
 
  protected:
-  std::vector<int> GetInputPos() const override {
-    std::vector<int> input_pos = {1};
-    return input_pos;
-  }
+  std::vector<int> GetInputPos() const override { return {1}; }
 
   std::set<int> GetOutputPos() const override {
     std::set<int> output_pos{0};
@@ -1447,6 +1457,14 @@ class SplitProcessor : public ConcatProcessor {
     }
     return output_pos;
   }
+
+  Status CustomizedProcessing() override {
+    TF_RETURN_IF_ERROR(UpdateOrTransformParamInput(
+        axis_node_pos_, "DataFormatDimMap", DT_INT32));
+    return Status::OK();
+  }
+
+  int axis_node_pos_;
 };
 
 class SplitVProcessor : public SplitProcessor {
@@ -1533,13 +1551,9 @@ class SliceProcessor : public AgnosticNodeProcessor {
   Status CustomizedProcessing() override {
     // Skip the first input, which is the data to be sliced.
     for (int i = 1; i < node_->input_size(); i++) {
-      auto index_node = node_map_->GetNode(node_->input(i));
-      if (IsConstant(*index_node)) {
-        TF_RETURN_IF_ERROR(UpdateAttrValueOfInput(i));
-      } else {
-        AddDataFormatTranformToInput("DataFormatVecPermute", i,
-                                     node_->attr().at("Index").type());
-      }
+      DataType dtype = node_->attr().at("Index").type();
+      TF_RETURN_IF_ERROR(
+          UpdateOrTransformParamInput(i, "DataFormatVecPermute", dtype));
     }
     return Status::OK();
   }
@@ -1614,7 +1628,9 @@ class SumProcessor : public AgnosticNodeProcessor {
 
   Status AddLayoutTransposeToOutputs() override { return Status::OK(); }
 
-  Status CustomizedProcessing() override { return UpdateAttrValueOfInput(1); }
+  Status CustomizedProcessing() override {
+    return UpdateAttrValueOfInput(1, false);
+  }
 
  private:
   bool IsAlongDimNHW() const {
@@ -1765,6 +1781,8 @@ class DataLayoutOptimizer : GraphProcessor {
           node_processor.reset(new MergeProcessor(opt_cxt));
         } else if (IsPad(*node)) {
           node_processor.reset(new PadProcessor(opt_cxt));
+        } else if (IsReverseV2(*node)) {
+          node_processor.reset(new ReverseProcessor(opt_cxt));
         } else if (IsSlice(*node)) {
           node_processor.reset(new SliceProcessor(opt_cxt));
         } else if (IsShape(*node) || IsShapeN(*node)) {
diff --git a/tensorflow/core/grappler/optimizers/layout_optimizer_test.cc b/tensorflow/core/grappler/optimizers/layout_optimizer_test.cc
index 781f942532..161cf7ad96 100644
--- a/tensorflow/core/grappler/optimizers/layout_optimizer_test.cc
+++ b/tensorflow/core/grappler/optimizers/layout_optimizer_test.cc
@@ -559,11 +559,9 @@ TEST_F(LayoutOptimizerTest, SplitNonConstDim) {
   Status status = optimizer.Optimize(virtual_cluster_.get(), item, &output);
   NodeMap node_map(&output);
   auto split_node = node_map.GetNode("split");
-  EXPECT_EQ(split_node->input(0),
-            "LayoutOptimizerVecPermuteNHWCToNCHW_split_0");
+  EXPECT_EQ(split_node->input(0), "LayoutOptimizerDimMapNHWCToNCHW_split_0");
   EXPECT_EQ(split_node->input(1), "Conv2D");
-  auto map_node =
-      node_map.GetNode("LayoutOptimizerVecPermuteNHWCToNCHW_split_0");
+  auto map_node = node_map.GetNode("LayoutOptimizerDimMapNHWCToNCHW_split_0");
   EXPECT_EQ(map_node->op(), "DataFormatDimMap");
   EXPECT_EQ(map_node->input(0), "i1");
 }
@@ -629,10 +627,9 @@ TEST_F(LayoutOptimizerTest, ConcatNonConst) {
   auto concat_node = node_map.GetNode("concat");
   EXPECT_EQ(concat_node->input(0), "split");
   EXPECT_EQ(concat_node->input(1), "split:1");
-  EXPECT_EQ(concat_node->input(2),
-            "LayoutOptimizerVecPermuteNHWCToNCHW_concat_2");
+  EXPECT_EQ(concat_node->input(2), "LayoutOptimizerDimMapNHWCToNCHW_concat_2");
   auto concat_dim =
-      node_map.GetNode("LayoutOptimizerVecPermuteNHWCToNCHW_concat_2");
+      node_map.GetNode("LayoutOptimizerDimMapNHWCToNCHW_concat_2");
   EXPECT_EQ(concat_dim->op(), "DataFormatDimMap");
   EXPECT_EQ(concat_dim->input(0), "i");
 }
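
The refactoring above funnels every format-sensitive parameter input through the new UpdateOrTransformParamInput helper, which distinguishes two rewrites: DataFormatDimMap for inputs that hold dimension indices (Concat/Split axes, ReverseV2 dims) and DataFormatVecPermute for inputs that hold one value per dimension (paddings, sizes, pooling parameters). Constant inputs are rewritten in place by UpdateAttrValue; non-constant inputs get a conversion node inserted in front of them. A plain-Python sketch of the two NHWC-to-NCHW mappings, mirroring the constant path above (illustrative only, not TensorFlow code):

    # Rough Python equivalents of the two NHWC->NCHW rewrites used above.
    # These mirror UpdateAttrValue's constant path; the non-constant path
    # inserts DataFormatDimMap / DataFormatVecPermute nodes instead.

    def dim_map_nhwc_to_nchw(dim):
        # Dimension *indices*: N(0)->0, H(1)->2, W(2)->3, C(3)->1.
        dim = dim if dim >= 0 else dim + 4  # normalize negative axes
        return {0: 0, 1: 2, 2: 3, 3: 1}[dim]

    def vec_permute_nhwc_to_nchw(vec):
        # Per-dimension *values* (e.g. paddings, output sizes):
        # reorder [n, h, w, c] -> [n, c, h, w].
        n, h, w, c = vec
        return [n, c, h, w]

    # Example: a ReverseV2 axis constant [3, 1] (reverse C and H in NHWC)
    # becomes [1, 2] once the surrounding graph is rewritten to NCHW.
    assert [dim_map_nhwc_to_nchw(d) for d in [3, 1]] == [1, 2]
    assert vec_permute_nhwc_to_nchw([1, 28, 28, 3]) == [1, 3, 28, 28]
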
diff --git a/tensorflow/core/kernels/constant_op.cc b/tensorflow/core/kernels/constant_op.cc
index 103a0e225e..c485d02e23 100644
--- a/tensorflow/core/kernels/constant_op.cc
+++ b/tensorflow/core/kernels/constant_op.cc
@@ -164,25 +164,24 @@ struct FillFunctor<CPUDevice, T> {
 
 }  // end namespace functor
 
-template <typename Device, typename T>
+template <typename Device, typename T, typename Index>
 class FillOp : public OpKernel {
  public:
  explicit FillOp(OpKernelConstruction* context) : OpKernel(context) {}
 
  void Compute(OpKernelContext* context) override {
    const Tensor& Tdims = context->input(0);
-    OP_REQUIRES(
-        context, IsLegacyVector(Tdims.shape()),
-        errors::InvalidArgument("dims must be a vector of int32, got shape ",
-                                Tdims.shape().DebugString()));
+    OP_REQUIRES(context, IsLegacyVector(Tdims.shape()),
+                errors::InvalidArgument("dims must be a vector, got shape ",
+                                        Tdims.shape().DebugString()));
    const Tensor& Tvalue = context->input(1);
    OP_REQUIRES(context, IsLegacyScalar(Tvalue.shape()),
                errors::InvalidArgument("value must be a scalar, got shape ",
                                        Tvalue.shape().DebugString()));
-    auto dims = Tdims.flat<int32>();
+    auto dims = Tdims.flat<Index>();
    TensorShape shape;
    OP_REQUIRES_OK(context, TensorShapeUtils::MakeShape(
-                                reinterpret_cast<const int32*>(dims.data()),
+                                reinterpret_cast<const Index*>(dims.data()),
                                dims.size(), &shape));
    Tensor* out = nullptr;
    OP_REQUIRES_OK(context, context->allocate_output(0, shape, &out));
@@ -214,12 +213,19 @@ struct FillFunctor<SYCLDevice, T> {
 }  // namespace functor
 #endif  // TENSORFLOW_USE_SYCL
 
-#define REGISTER_KERNEL(D, TYPE)                         \
-  REGISTER_KERNEL_BUILDER(Name("Fill")                   \
-                              .Device(DEVICE_##D)        \
-                              .TypeConstraint<TYPE>("T") \
-                              .HostMemory("dims"),       \
-                          FillOp<D##Device, TYPE>);
+#define REGISTER_KERNEL(D, TYPE)                                   \
+  REGISTER_KERNEL_BUILDER(Name("Fill")                             \
+                              .Device(DEVICE_##D)                  \
+                              .TypeConstraint<TYPE>("T")           \
+                              .TypeConstraint<int32>("index_type") \
+                              .HostMemory("dims"),                 \
+                          FillOp<D##Device, TYPE, int32>);         \
+  REGISTER_KERNEL_BUILDER(Name("Fill")                             \
+                              .Device(DEVICE_##D)                  \
+                              .TypeConstraint<TYPE>("T")           \
+                              .TypeConstraint<int64>("index_type") \
+                              .HostMemory("dims"),                 \
+                          FillOp<D##Device, TYPE, int64>);
 
 #define REGISTER_CPU_KERNEL(TYPE) REGISTER_KERNEL(CPU, TYPE)
 TF_CALL_ALL_TYPES(REGISTER_CPU_KERNEL);
@@ -241,10 +247,11 @@ REGISTER_KERNEL(SYCL, int64);
 REGISTER_KERNEL_BUILDER(Name("Fill")
                            .Device(DEVICE_SYCL)
                            .TypeConstraint<int32>("T")
+                            .TypeConstraint<int32>("index_type")
                            .HostMemory("dims")
                            .HostMemory("value")
                            .HostMemory("output"),
-                        FillOp<CPUDevice, int32>);
+                        FillOp<CPUDevice, int32, int32>);
 #undef REGISTER_KERNEL_SYCL
 #endif  // TENSORFLOW_USE_SYCL
 
@@ -267,10 +274,11 @@ REGISTER_KERNEL(GPU, bool);
 REGISTER_KERNEL_BUILDER(Name("Fill")
                            .Device(DEVICE_GPU)
                            .TypeConstraint<int32>("T")
+                            .TypeConstraint<int32>("index_type")
                            .HostMemory("dims")
                            .HostMemory("value")
                            .HostMemory("output"),
-                        FillOp<CPUDevice, int32>);
+                        FillOp<CPUDevice, int32, int32>);
 #endif
 
 #undef REGISTER_KERNEL
diff --git a/tensorflow/core/ops/array_ops.cc b/tensorflow/core/ops/array_ops.cc
index 5a31f433ce..c7d9b97461 100644
--- a/tensorflow/core/ops/array_ops.cc
+++ b/tensorflow/core/ops/array_ops.cc
@@ -1328,11 +1328,17 @@ The output will be:
 
 // --------------------------------------------------------------------------
 REGISTER_OP("Fill")
-    .Input("dims: int32")
+    .Input("dims: index_type")
    .Input("value: T")
    .Output("output: T")
    .Attr("T: type")
+    .Attr("index_type: {int32, int64} = DT_INT32")
    .SetShapeFn([](InferenceContext* c) {
+      DataType index_type = DT_INT32;
+      Status s = c->GetAttr("index_type", &index_type);
+      if (!s.ok() && s.code() != error::NOT_FOUND) {
+        return s;
+      }
      ShapeHandle unused;
      TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 1, &unused));
      TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused));
@@ -1340,7 +1346,8 @@ REGISTER_OP("Fill")
      const Tensor* t = c->input_tensor(0);
      if (t != nullptr) {
        for (int i = 0; i < t->NumElements(); ++i) {
-          if (t->vec<int32>()(i) < 0) {
+          if ((index_type == DT_INT32 && t->vec<int32>()(i) < 0) ||
+              (index_type == DT_INT64 && t->vec<int64>()(i) < 0)) {
            return errors::InvalidArgument("Fill dimensions must be >= 0");
          }
        }
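
With Fill's dims input generalized from a fixed int32 to the new index_type attr (int32 or int64, defaulting to int32 for backward compatibility), int64 shape vectors are accepted directly, and the shape function validates non-negativity for either type. A small TF 1.x-style usage sketch of the behavior these changes enable (tensor values are illustrative):

    import tensorflow as tf

    # int64 dims are now accepted; previously Fill insisted on an int32
    # vector and failed with "dims must be a vector of int32".
    dims = tf.constant([2, 3], dtype=tf.int64)
    filled = tf.fill(dims, 7.0)  # Tensor of shape [2, 3], every element 7.0

    with tf.Session() as sess:
        print(sess.run(filled))

    # Negative dimensions are still rejected for either index type by the
    # shape function above ("Fill dimensions must be >= 0").
    try:
        tf.fill(tf.constant([-1], dtype=tf.int64), 0.0)
    except ValueError as e:
        print(e)
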
diff --git a/tensorflow/core/ops/array_ops_test.cc b/tensorflow/core/ops/array_ops_test.cc
index c8ea443613..a182fd1c47 100644
--- a/tensorflow/core/ops/array_ops_test.cc
+++ b/tensorflow/core/ops/array_ops_test.cc
@@ -253,6 +253,7 @@ TEST(ArrayOpsTest, ReverseV2_ShapeFn) {
 
 TEST(ArrayOpsTest, Fill_ShapeFn) {
   ShapeInferenceTestOp op("Fill");
+  AddNodeAttr("index_type", DT_INT32, &op.node_def);
   op.input_tensors.resize(2);
   INFER_OK(op, "?;?", "?");
   INFER_OK(op, "[?];?", "?");
diff --git a/tensorflow/python/eager/ops_test.py b/tensorflow/python/eager/ops_test.py
index 48dcb4830c..f8c5037dcf 100644
--- a/tensorflow/python/eager/ops_test.py
+++ b/tensorflow/python/eager/ops_test.py
@@ -255,11 +255,12 @@ class OpsTest(test_util.TensorFlowTestCase):
                                  'using.*DEVICE_PLACEMENT_SILENT'):
       reshaped = array_ops.reshape(value, shape.gpu())
 
-  def testInvalidInputDataType(self):
+  def testInt64(self):
     # Fill requires the first input to be an int32 tensor.
-    with self.assertRaisesRegexp(errors.InvalidArgumentError, 'int64'):
-      array_ops.fill(constant_op.constant([2], dtype=dtypes.int64),
-                     constant_op.constant(1))
+    self.assertAllEqual(
+        [1.0, 1.0],
+        array_ops.fill(constant_op.constant([2], dtype=dtypes.int64),
+                       constant_op.constant(1)))
 
   def testOutputOnHostMemory(self):
     if not context.context().num_gpus():
diff --git a/tensorflow/python/framework/constant_op.py b/tensorflow/python/framework/constant_op.py
index bf3be34d85..ac915157f5 100644
--- a/tensorflow/python/framework/constant_op.py
+++ b/tensorflow/python/framework/constant_op.py
@@ -45,6 +45,7 @@ import numpy as np
 import six
 
 from tensorflow.core.framework import attr_value_pb2
+from tensorflow.core.framework import types_pb2
 from tensorflow.python.eager import context
 from tensorflow.python.eager import execute
 from tensorflow.python.framework import dtypes
@@ -71,7 +72,7 @@ def _eager_fill(dims, value, ctx):
   attr_t = value.dtype.as_datatype_enum
   dims = convert_to_eager_tensor(dims, ctx, dtypes.int32)
   inputs_flat = [dims, value]
-  attrs = ("T", attr_t)
+  attrs = ("T", attr_t, "index_type", types_pb2.DT_INT32)
   result, = execute.execute(
       b"Fill", 1, inputs=inputs_flat, attrs=attrs, ctx=ctx)
   return result
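
In the eager path, op attrs are passed to execute.execute as a flat tuple of alternating names and values, so _eager_fill must now also supply the new index_type attr; it pins it to DT_INT32 because it converts dims to int32 up front. The rewritten testInt64 above exercises the int64 support end to end; a standalone repro would look roughly like this (the eager-enablement entry point is a TF 1.5-era assumption):

    import tensorflow as tf
    import tensorflow.contrib.eager as tfe  # assumed TF 1.5-era eager API

    tfe.enable_eager_execution()

    # Before this change, Fill rejected int64 dims with an
    # InvalidArgumentError mentioning 'int64'; now it succeeds.
    result = tf.fill(tf.constant([2], dtype=tf.int64), tf.constant(1))
    print(result)  # -> [1 1]
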
diff --git a/tensorflow/python/grappler/layout_optimizer_test.py b/tensorflow/python/grappler/layout_optimizer_test.py
index 2c15e340c9..909981982d 100644
--- a/tensorflow/python/grappler/layout_optimizer_test.py
+++ b/tensorflow/python/grappler/layout_optimizer_test.py
@@ -279,7 +279,7 @@ class LayoutOptimizerTest(test.TestCase):
       self.assertEqual(expected_num_transposes, num_transposes)
       self.assertIn('LayoutOptimizerTransposeNHWCToNCHW-Conv2D-0', nodes)
       self.assertIn('LayoutOptimizerTransposeNCHWToNHWC-split-0-0', nodes)
-      self.assertIn('LayoutOptimizerVecPermuteNHWCToNCHW_split_0', nodes)
+      self.assertIn('LayoutOptimizerDimMapNHWCToNCHW_split_0', nodes)
       self.assertAllClose(output_val_ref, output_val, atol=1e-3)
 
   def testSplitVWithNonConstAxis(self):
@@ -313,7 +313,7 @@ class LayoutOptimizerTest(test.TestCase):
       self.assertEqual(expected_num_transposes, num_transposes)
       self.assertIn('LayoutOptimizerTransposeNHWCToNCHW-Conv2D-0', nodes)
       self.assertIn('LayoutOptimizerTransposeNCHWToNHWC-SplitV-0-0', nodes)
-      self.assertIn('LayoutOptimizerVecPermuteNHWCToNCHW_SplitV_2', nodes)
+      self.assertIn('LayoutOptimizerDimMapNHWCToNCHW_SplitV_2', nodes)
       self.assertAllClose(output_val_ref, output_val, atol=1e-3)
 
   def testPadWithConstPaddings(self):
@@ -350,6 +350,74 @@ class LayoutOptimizerTest(test.TestCase):
       self.assertIn('LayoutOptimizer-Pad-PaddingsConst', nodes)
       self.assertAllClose(output_val_ref, output_val, atol=1e-3)
 
+  def testReverseWithConstDims(self):
+    if test.is_gpu_available(cuda_only=True):
+      random_seed.set_random_seed(0)
+      x = random_ops.truncated_normal([1, 784], seed=0)
+      conv = _two_layer_model(x)
+      dims = constant_op.constant([3, 1], name='DimsConst')
+      reverse = array_ops.reverse(conv, dims)
+      output = array_ops.identity(reverse)
+
+      with session.Session() as sess:
+        output_val_ref = sess.run(output)
+
+      with session.Session(config=_get_config()) as sess:
+        metadata = config_pb2.RunMetadata()
+        output_val = sess.run(output, run_metadata=metadata)
+
+      nodes = []
+      num_transposes = 0
+      for node in metadata.cost_graph.node:
+        if node.name.startswith('LayoutOptimizerTranspose'):
+          num_transposes += 1
+        nodes.append(node.name)
+
+      # Four transposes were initially added in the Expand phase of
+      # LayoutOptimizer; two of them are cancelled out in the Collapse phase.
+      expected_num_transposes = 2
+      self.assertEqual(expected_num_transposes, num_transposes)
+      self.assertIn('LayoutOptimizerTransposeNHWCToNCHW-Conv2D-0', nodes)
+      self.assertIn('LayoutOptimizerTransposeNCHWToNHWC-ReverseV2-0-0', nodes)
+      self.assertIn('LayoutOptimizer-ReverseV2-DimsConst', nodes)
+      self.assertAllClose(output_val_ref, output_val, atol=1e-3)
+
+  def testReverseWithNonConstDims(self):
+    if test.is_gpu_available(cuda_only=True):
+      random_seed.set_random_seed(0)
+      x = random_ops.truncated_normal([1, 784], seed=0)
+      conv = _two_layer_model(x)
+      dims = array_ops.placeholder(dtype='int32')
+      reverse = array_ops.reverse(conv, dims)
+      output = array_ops.identity(reverse)
+
+      dims_val = [2, 3]
+      with session.Session() as sess:
+        output_val_ref = sess.run(output, feed_dict={dims: dims_val})
+
+      with session.Session(config=_get_config()) as sess:
+        metadata = config_pb2.RunMetadata()
+        output_val = sess.run(
+            output, run_metadata=metadata, feed_dict={
+                dims: dims_val
+            })
+
+      nodes = []
+      num_transposes = 0
+      for node in metadata.cost_graph.node:
+        if node.name.startswith('LayoutOptimizerTranspose'):
+          num_transposes += 1
+        nodes.append(node.name)
+
+      # Four transposes were initially added in the Expand phase of
+      # LayoutOptimizer; two of them are cancelled out in the Collapse phase.
+      expected_num_transposes = 2
+      self.assertEqual(expected_num_transposes, num_transposes)
+      self.assertIn('LayoutOptimizerTransposeNHWCToNCHW-Conv2D-0', nodes)
+      self.assertIn('LayoutOptimizerTransposeNCHWToNHWC-ReverseV2-0-0', nodes)
+      self.assertIn('LayoutOptimizerDimMapNHWCToNCHW_ReverseV2_1', nodes)
+      self.assertAllClose(output_val_ref, output_val, atol=1e-3)
+
   def testTernaryOp(self):
     if test.is_gpu_available(cuda_only=True):
       random_seed.set_random_seed(0)
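
Both new tests follow the existing Pad tests: with a constant dims input the optimizer rewrites the axis constant in place (the LayoutOptimizer-ReverseV2-DimsConst node), while a placeholder input gets a DataFormatDimMap node inserted (LayoutOptimizerDimMapNHWCToNCHW_ReverseV2_1). Outside the test harness the trigger is simply a reverse over a conv feature map; a rough sketch (model and names are illustrative):

    import tensorflow as tf

    # NHWC graph: reverse the C (axis 3) and H (axis 1) dimensions of a
    # conv output. On a CUDA GPU the layout optimizer converts the conv
    # to NCHW; because the axis vector is a constant it is rewritten in
    # place from [3, 1] to [1, 2] via the dim map (0->0, 1->2, 2->3,
    # 3->1), with no DataFormatDimMap node inserted.
    images = tf.random_normal([8, 28, 28, 3])
    conv = tf.layers.conv2d(images, filters=16, kernel_size=3)
    rev = tf.reverse(conv, axis=tf.constant([3, 1]))
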
diff --git a/tensorflow/python/ops/parsing_ops.py b/tensorflow/python/ops/parsing_ops.py
index eba40c4f85..df1f00ddcd 100644
--- a/tensorflow/python/ops/parsing_ops.py
+++ b/tensorflow/python/ops/parsing_ops.py
@@ -385,7 +385,7 @@ def parse_example(serialized, features, name=None, example_names=None):
   A `values[i]` comes from a position `k` in the feature of an example at batch
   entry `batch`. This positional information is recorded in `indices[i]` as
   `[batch, index_0, index_1, ...]` where `index_j` is the `k-th` value of
-  the feature in the example at with key `SparseFeature.index_key[j].
+  the feature in the example at with key `SparseFeature.index_key[j]`.
   In other words, we split the indices (except the first index indicating the
   batch entry) of a `SparseTensor` by dimension into different features of the
   `Example`. Due to its complexity a `VarLenFeature` should be preferred over a
@@ -1234,7 +1234,7 @@ def parse_single_example_v2(serialized, features, name=None):
   A `values[i]` comes from a position `k` in the feature of an example at batch
   entry `batch`. This positional information is recorded in `indices[i]` as
   `[batch, index_0, index_1, ...]` where `index_j` is the `k-th` value of
-  the feature in the example at with key `SparseFeature.index_key[j].
+  the feature in the example at with key `SparseFeature.index_key[j]`.
   In other words, we split the indices (except the first index indicating the
   batch entry) of a `SparseTensor` by dimension into different features of the
   `Example`. Due to its complexity a `VarLenFeature` should be preferred over a
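
The docstring line fixed in both hunks (a missing closing backtick after `SparseFeature.index_key[j]`) describes how a multi-key SparseFeature scatters values: value k of value_key lands at [batch, index_0[k], index_1[k], ...]. A compact, self-contained illustration (feature names and sizes are made up):

    import tensorflow as tf

    # Value k of 'val' lands at [batch, ix0[k], ix1[k]] in the resulting
    # SparseTensor, matching the `SparseFeature.index_key[j]` wording above.
    example = tf.train.Example(features=tf.train.Features(feature={
        'ix0': tf.train.Feature(int64_list=tf.train.Int64List(value=[1, 4])),
        'ix1': tf.train.Feature(int64_list=tf.train.Int64List(value=[2, 5])),
        'val': tf.train.Feature(int64_list=tf.train.Int64List(value=[7, 9])),
    }))
    spec = {'sp': tf.SparseFeature(index_key=['ix0', 'ix1'], value_key='val',
                                   dtype=tf.int64, size=[10, 10])}
    parsed = tf.parse_example([example.SerializeToString()], spec)

    with tf.Session() as sess:
        sp = sess.run(parsed['sp'])
        print(sp.indices)  # [[0 1 2] [0 4 5]]
        print(sp.values)   # [7 9]
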