-rw-r--r--  tensorflow/core/grappler/op_types.cc                           |   2
-rw-r--r--  tensorflow/core/grappler/op_types.h                            |   1
-rw-r--r--  tensorflow/core/grappler/optimizers/constant_folding.cc        |   2
-rw-r--r--  tensorflow/core/grappler/optimizers/layout_optimizer.cc        | 430
-rw-r--r--  tensorflow/core/grappler/optimizers/layout_optimizer_test.cc   |  11
-rw-r--r--  tensorflow/core/kernels/constant_op.cc                         |  38
-rw-r--r--  tensorflow/core/ops/array_ops.cc                               |  11
-rw-r--r--  tensorflow/core/ops/array_ops_test.cc                          |   1
-rw-r--r--  tensorflow/python/eager/ops_test.py                            |   9
-rw-r--r--  tensorflow/python/framework/constant_op.py                     |   3
-rw-r--r--  tensorflow/python/grappler/layout_optimizer_test.py            |  72
-rw-r--r--  tensorflow/python/ops/parsing_ops.py                           |   4
12 files changed, 344 insertions(+), 240 deletions(-)
diff --git a/tensorflow/core/grappler/op_types.cc b/tensorflow/core/grappler/op_types.cc
index 6e7558c00e..e517159721 100644
--- a/tensorflow/core/grappler/op_types.cc
+++ b/tensorflow/core/grappler/op_types.cc
@@ -225,6 +225,8 @@ bool IsRestore(const NodeDef& node) {
node.op() == "RestoreSlice");
}
+bool IsReverseV2(const NodeDef& node) { return node.op() == "ReverseV2"; }
+
bool IsRsqrtGrad(const NodeDef& node) { return node.op() == "RsqrtGrad"; }
bool IsSelect(const NodeDef& node) { return node.op() == "Select"; }
diff --git a/tensorflow/core/grappler/op_types.h b/tensorflow/core/grappler/op_types.h
index d9dc147a39..28d6ce52e1 100644
--- a/tensorflow/core/grappler/op_types.h
+++ b/tensorflow/core/grappler/op_types.h
@@ -89,6 +89,7 @@ bool IsRecv(const NodeDef& node);
bool IsReduction(const NodeDef& node);
bool IsReshape(const NodeDef& node);
bool IsRestore(const NodeDef& node);
+bool IsReverseV2(const NodeDef& node);
bool IsRsqrtGrad(const NodeDef& node);
bool IsSelect(const NodeDef& node);
bool IsSeluGrad(const NodeDef& node);
diff --git a/tensorflow/core/grappler/optimizers/constant_folding.cc b/tensorflow/core/grappler/optimizers/constant_folding.cc
index 59df49c245..62de52316e 100644
--- a/tensorflow/core/grappler/optimizers/constant_folding.cc
+++ b/tensorflow/core/grappler/optimizers/constant_folding.cc
@@ -1486,7 +1486,7 @@ Status ConstantFolding::SimplifyGraph(GraphDef* output,
// TODO(rmlarsen): Handle non-associative/non-commutative operators like
// subtraction and division, as well as mixed subtraction/addition,
// division/multiplication.
- if ((is_add || is_mul) && NumNonControlInputs(*node) == 2) {
+ if ((IsAdd(*node) || is_mul) && NumNonControlInputs(*node) == 2) {
NodeDef* left_child = node_map_->GetNode(node->input(0));
NodeDef* right_child = node_map_->GetNode(node->input(1));
// One child must be constant, and the other the same op as the parent.
diff --git a/tensorflow/core/grappler/optimizers/layout_optimizer.cc b/tensorflow/core/grappler/optimizers/layout_optimizer.cc
index 6bedcc4b36..3b52728b3b 100644
--- a/tensorflow/core/grappler/optimizers/layout_optimizer.cc
+++ b/tensorflow/core/grappler/optimizers/layout_optimizer.cc
@@ -41,6 +41,8 @@ const char kPermNHWCToNCHW[] = "LayoutOptimizerPermConstNHWCToNCHW";
const char kPermNCHWToNHWC[] = "LayoutOptimizerPermConstNCHWToNHWC";
const char kTransposeNHWCToNCHW[] = "LayoutOptimizerTransposeNHWCToNCHW";
const char kTransposeNCHWToNHWC[] = "LayoutOptimizerTransposeNCHWToNHWC";
+const char kDimMapNHWCToNCHW[] = "LayoutOptimizerDimMapNHWCToNCHW";
+const char kDimMapNCHWToNHWC[] = "LayoutOptimizerDimMapNCHWToNHWC";
const char kVecPermuteNHWCToNCHW[] = "LayoutOptimizerVecPermuteNHWCToNCHW";
const char kVecPermuteNCHWToNHWC[] = "LayoutOptimizerVecPermuteNCHWToNHWC";
const char kReshapeNHWCToNCHW[] = "LayoutOptimizerReshapeNHWCToNCHW";
@@ -179,6 +181,7 @@ std::set<string> GetOpsFormatAgnostic() {
"Switch",
"TruncateDiv",
"TruncateMod",
+ "ReverseV2",
"Round",
"Rsqrt",
"RsqrtGrad",
@@ -453,12 +456,7 @@ class NodeProcessor : public GraphProcessor {
return nodes_to_preserve_.find(node_->name()) != nodes_to_preserve_.end();
}
- virtual bool ShouldProcess() const {
- return !MustPreserve() && IsNHWC() && IsPortZeroDimsFour(*node_) &&
- HasOutputs() && IsOnGPU();
- }
-
- virtual bool IsOnGPU() const {
+ bool IsOnGPU() const {
string device_name;
if (node_->device().empty()) {
device_name = virtual_placer_.get_canonical_device_name(*node_);
@@ -475,14 +473,9 @@ class NodeProcessor : public GraphProcessor {
return false;
}
- void UpdateAttrDataFormat() {
- if (node_->attr().find("data_format") != node_->attr().end()) {
- if (node_->attr().at("data_format").s().compare("NHWC") == 0) {
- string* data_format =
- node_->mutable_attr()->at("data_format").mutable_s();
- *data_format = "NCHW";
- }
- }
+ virtual bool ShouldProcess() const {
+ return !MustPreserve() && IsNHWC() && IsPortZeroDimsFour(*node_) &&
+ HasOutputs() && IsOnGPU();
}
virtual void UpdateAttrShape() {
@@ -504,74 +497,7 @@ class NodeProcessor : public GraphProcessor {
}
}
- void UpdateAttrKSize() {
- if (node_->attr().find("ksize") != node_->attr().end()) {
- auto list = node_->mutable_attr()->at("ksize").mutable_list();
- UpdateTuple(list);
- }
- }
-
- void UpdateAttrStrides() {
- if (node_->attr().find("strides") != node_->attr().end()) {
- auto list = node_->mutable_attr()->at("strides").mutable_list();
- UpdateTuple(list);
- }
- }
-
- Status UpdateAttrValue(NodeDef* node) {
- TF_RETURN_IF_ERROR(HasAttribute(*node, "value"));
- Tensor tensor;
- auto success =
- tensor.FromProto(node->mutable_attr()->at({"value"}).tensor());
- if (!success) {
- LOG(ERROR) << "Failed to parse TensorProto.";
- }
- if (tensor.dims() == 0) {
- int value = tensor.scalar<int>()();
- value = (value >= 0) ? value : value + 4;
- if (value == 1 || value == 2) {
- value = value + 1;
- } else if (value == 3) {
- value = 1;
- }
- tensor.scalar<int>()() = value;
- } else if (tensor.dims() == 1) {
- if (tensor.flat<int>().size() == 4) {
- int c = tensor.flat<int>()(3);
- tensor.flat<int>()(3) = tensor.flat<int>()(2);
- tensor.flat<int>()(2) = tensor.flat<int>()(1);
- tensor.flat<int>()(1) = c;
- } else if (tensor.flat<int>().size() == 3) {
- tensor.flat<int>()(0) = 0;
- tensor.flat<int>()(1) = 2;
- tensor.flat<int>()(2) = 3;
- } else {
- return Status(error::INVALID_ARGUMENT,
- strings::StrCat("Unsupported tensor size: ",
- tensor.flat<int>().size()));
- }
- } else if (tensor.dims() == 2) {
- for (int i = 0; i < 2; i++) {
- int c = tensor.matrix<int>()(3, i);
- tensor.matrix<int>()(3, i) = tensor.matrix<int>()(2, i);
- tensor.matrix<int>()(2, i) = tensor.matrix<int>()(1, i);
- tensor.matrix<int>()(1, i) = c;
- }
- } else {
- return Status(
- error::INVALID_ARGUMENT,
- strings::StrCat("Unsupported dimension size: ", tensor.dims()));
- }
- if (tensor.dims() == 0) {
- tensor.AsProtoField(node->mutable_attr()->at({"value"}).mutable_tensor());
- } else {
- tensor.AsProtoTensorContent(
- node->mutable_attr()->at({"value"}).mutable_tensor());
- }
- return Status::OK();
- }
-
- Status UpdateAttrValueOfInput(int input_index) {
+ Status UpdateAttrValueOfInput(int input_index, bool permute) {
auto input_node = node_map_->GetNode(node_->input(input_index));
// We created a copy of the node, so that we don't modify the original node,
// which might be used elsewhere. Note that this copy also copies the
@@ -585,7 +511,7 @@ class NodeProcessor : public GraphProcessor {
*node_->mutable_input(input_index) = node_name;
node_map_->AddNode(node_name, added_node);
node_map_->AddOutput(node_name, node_->name());
- return UpdateAttrValue(added_node);
+ return UpdateAttrValue(added_node, permute);
}
virtual std::vector<int> GetInputPos() const {
@@ -601,42 +527,6 @@ class NodeProcessor : public GraphProcessor {
return output_pos;
}
- NodeDef* AddNodeTranspose(const string& node_name, const string& input_name,
- const string& const_name, DataType data_type,
- const TensorShapeProto& input_shape,
- bool NHWCToNCHW) {
- NodeDef* node = graph_->add_node();
- node_map_->AddNode(node_name, node);
- node->set_name(node_name);
- *node->add_input() = input_name;
- *node->add_input() = const_name;
- node->set_op("Transpose");
- node->set_device(node_->device());
- AttrValue attr_data_type;
- attr_data_type.set_type(data_type);
- node->mutable_attr()->insert({"T", attr_data_type});
- AttrValue attr_data_type_perm;
- attr_data_type_perm.set_type(DT_INT32);
- node->mutable_attr()->insert({"Tperm", attr_data_type_perm});
- if (!input_shape.unknown_rank()) {
- AttrValue attr_output_shape;
- auto output_shape = attr_output_shape.mutable_list()->add_shape();
- if (NHWCToNCHW) {
- output_shape->add_dim()->set_size(input_shape.dim(0).size());
- output_shape->add_dim()->set_size(input_shape.dim(3).size());
- output_shape->add_dim()->set_size(input_shape.dim(1).size());
- output_shape->add_dim()->set_size(input_shape.dim(2).size());
- } else {
- output_shape->add_dim()->set_size(input_shape.dim(0).size());
- output_shape->add_dim()->set_size(input_shape.dim(2).size());
- output_shape->add_dim()->set_size(input_shape.dim(3).size());
- output_shape->add_dim()->set_size(input_shape.dim(1).size());
- }
- node->mutable_attr()->insert({"_output_shapes", attr_output_shape});
- }
- return node;
- }
-
virtual Status AddLayoutTransposeToInputs() {
std::vector<int> input_pos = GetInputPos();
for (const auto& pos : input_pos) {
@@ -732,6 +622,137 @@ class NodeProcessor : public GraphProcessor {
virtual Status CustomizedProcessing() { return Status::OK(); }
+ Status UpdateOrTransformParamInput(int param_index, const string& op,
+ DataType dtype) {
+ auto param_node = node_map_->GetNode(node_->input(param_index));
+ bool permute = (op == "DataFormatVecPermute") ? true : false;
+ if (IsConstant(*param_node)) {
+ TF_RETURN_IF_ERROR(UpdateAttrValueOfInput(param_index, permute));
+ } else {
+ AddDataFormatTranformToParamInput(op, param_index, dtype);
+ }
+ return Status::OK();
+ }
+
+ NodeDef* node_;
+ bool is_in_frame_;
+
+ private:
+ void UpdateAttrKSize() {
+ if (node_->attr().find("ksize") != node_->attr().end()) {
+ auto list = node_->mutable_attr()->at("ksize").mutable_list();
+ UpdateTuple(list);
+ }
+ }
+
+ void UpdateAttrStrides() {
+ if (node_->attr().find("strides") != node_->attr().end()) {
+ auto list = node_->mutable_attr()->at("strides").mutable_list();
+ UpdateTuple(list);
+ }
+ }
+
+ void UpdateAttrDataFormat() {
+ if (node_->attr().find("data_format") != node_->attr().end()) {
+ if (node_->attr().at("data_format").s().compare("NHWC") == 0) {
+ string* data_format =
+ node_->mutable_attr()->at("data_format").mutable_s();
+ *data_format = "NCHW";
+ }
+ }
+ }
+
+ Status UpdateAttrValue(NodeDef* node, bool permute) {
+ TF_RETURN_IF_ERROR(HasAttribute(*node, "value"));
+ Tensor tensor;
+ auto success =
+ tensor.FromProto(node->mutable_attr()->at({"value"}).tensor());
+ if (!success) {
+ LOG(ERROR) << "Failed to parse TensorProto.";
+ }
+
+ if (permute) {
+ if (tensor.dims() == 1) {
+ if (tensor.flat<int>().size() == 4) {
+ int c = tensor.flat<int>()(3);
+ tensor.flat<int>()(3) = tensor.flat<int>()(2);
+ tensor.flat<int>()(2) = tensor.flat<int>()(1);
+ tensor.flat<int>()(1) = c;
+ } else {
+ return Status(error::INVALID_ARGUMENT,
+ strings::StrCat("Unsupported tensor size: ",
+ tensor.flat<int>().size()));
+ }
+ } else if (tensor.dims() == 2) {
+ for (int i = 0; i < 2; i++) {
+ int c = tensor.matrix<int>()(3, i);
+ tensor.matrix<int>()(3, i) = tensor.matrix<int>()(2, i);
+ tensor.matrix<int>()(2, i) = tensor.matrix<int>()(1, i);
+ tensor.matrix<int>()(1, i) = c;
+ }
+ } else {
+ return Status(
+ error::INVALID_ARGUMENT,
+ strings::StrCat("Unsupported dimension size: ", tensor.dims()));
+ }
+ } else {
+ for (int i = 0; i < tensor.flat<int>().size(); i++) {
+ int value = tensor.flat<int>()(i);
+ value = (value >= 0) ? value : value + 4;
+ if (value == 1 || value == 2) {
+ value = value + 1;
+ } else if (value == 3) {
+ value = 1;
+ }
+ tensor.flat<int>()(i) = value;
+ }
+ }
+
+ if (tensor.dims() == 0) {
+ tensor.AsProtoField(node->mutable_attr()->at({"value"}).mutable_tensor());
+ } else {
+ tensor.AsProtoTensorContent(
+ node->mutable_attr()->at({"value"}).mutable_tensor());
+ }
+ return Status::OK();
+ }
+
+ NodeDef* AddNodeTranspose(const string& node_name, const string& input_name,
+ const string& const_name, DataType data_type,
+ const TensorShapeProto& input_shape,
+ bool NHWCToNCHW) {
+ NodeDef* node = graph_->add_node();
+ node_map_->AddNode(node_name, node);
+ node->set_name(node_name);
+ *node->add_input() = input_name;
+ *node->add_input() = const_name;
+ node->set_op("Transpose");
+ node->set_device(node_->device());
+ AttrValue attr_data_type;
+ attr_data_type.set_type(data_type);
+ node->mutable_attr()->insert({"T", attr_data_type});
+ AttrValue attr_data_type_perm;
+ attr_data_type_perm.set_type(DT_INT32);
+ node->mutable_attr()->insert({"Tperm", attr_data_type_perm});
+ if (!input_shape.unknown_rank()) {
+ AttrValue attr_output_shape;
+ auto output_shape = attr_output_shape.mutable_list()->add_shape();
+ if (NHWCToNCHW) {
+ output_shape->add_dim()->set_size(input_shape.dim(0).size());
+ output_shape->add_dim()->set_size(input_shape.dim(3).size());
+ output_shape->add_dim()->set_size(input_shape.dim(1).size());
+ output_shape->add_dim()->set_size(input_shape.dim(2).size());
+ } else {
+ output_shape->add_dim()->set_size(input_shape.dim(0).size());
+ output_shape->add_dim()->set_size(input_shape.dim(2).size());
+ output_shape->add_dim()->set_size(input_shape.dim(3).size());
+ output_shape->add_dim()->set_size(input_shape.dim(1).size());
+ }
+ node->mutable_attr()->insert({"_output_shapes", attr_output_shape});
+ }
+ return node;
+ }
+
NodeDef* AddNodePermNHWCToNCHW(const string& suffix,
const string& depended_node,
const string& device) {
@@ -754,44 +775,6 @@ class NodeProcessor : public GraphProcessor {
return const_node;
}
- NodeDef* AddNodeDataFormatOp(const string& name, const string& input_name,
- const string& op, DataType dtype,
- bool nhwc_to_nchw) {
- NodeDef* added_node = graph_->add_node();
- added_node->set_name(name);
- added_node->set_op(op);
- node_map_->AddNode(added_node->name(), added_node);
- added_node->set_device(node_->device());
- AttrValue attr_data_type;
- attr_data_type.set_type(dtype);
- added_node->mutable_attr()->insert({"T", attr_data_type});
- string src_format = (nhwc_to_nchw) ? "NHWC" : "NCHW";
- string dst_format = (nhwc_to_nchw) ? "NCHW" : "NHWC";
- AttrValue attr_format;
- attr_format.set_s(src_format);
- added_node->mutable_attr()->insert({"src_format", attr_format});
- attr_format.set_s(dst_format);
- added_node->mutable_attr()->insert({"dst_format", attr_format});
- *added_node->add_input() = input_name;
- return added_node;
- }
-
- void AddDataFormatTranformToInput(const string& op, int input_pos,
- DataType dtype) {
- string name = strings::StrCat(kVecPermuteNHWCToNCHW, "_", node_->name(),
- "_", input_pos);
- auto added_node =
- AddNodeDataFormatOp(name, node_->input(input_pos), op, dtype, true);
- *node_->mutable_input(input_pos) = added_node->name();
- node_map_->UpdateOutput(added_node->input(0), node_->name(),
- added_node->name());
- node_map_->AddOutput(added_node->name(), node_->name());
- }
-
- NodeDef* node_;
- bool is_in_frame_;
-
- private:
string GetOrAddNodePermNHWCToNCHW(int pos) {
string const_name;
if (is_in_frame_) {
@@ -833,6 +816,41 @@ class NodeProcessor : public GraphProcessor {
list->set_i(2, h);
list->set_i(3, w);
}
+
+ NodeDef* AddNodeDataFormatOp(const string& name, const string& input_name,
+ const string& op, DataType dtype,
+ bool nhwc_to_nchw) {
+ NodeDef* added_node = graph_->add_node();
+ added_node->set_name(name);
+ added_node->set_op(op);
+ node_map_->AddNode(added_node->name(), added_node);
+ added_node->set_device(node_->device());
+ AttrValue attr_data_type;
+ attr_data_type.set_type(dtype);
+ added_node->mutable_attr()->insert({"T", attr_data_type});
+ string src_format = (nhwc_to_nchw) ? "NHWC" : "NCHW";
+ string dst_format = (nhwc_to_nchw) ? "NCHW" : "NHWC";
+ AttrValue attr_format;
+ attr_format.set_s(src_format);
+ added_node->mutable_attr()->insert({"src_format", attr_format});
+ attr_format.set_s(dst_format);
+ added_node->mutable_attr()->insert({"dst_format", attr_format});
+ *added_node->add_input() = input_name;
+ return added_node;
+ }
+
+ void AddDataFormatTranformToParamInput(const string& op, int input_pos,
+ DataType dtype) {
+ string prefix = (op == "DataFormatVecPermute") ? kVecPermuteNHWCToNCHW
+ : kDimMapNHWCToNCHW;
+ string name = strings::StrCat(prefix, "_", node_->name(), "_", input_pos);
+ auto added_node =
+ AddNodeDataFormatOp(name, node_->input(input_pos), op, dtype, true);
+ *node_->mutable_input(input_pos) = added_node->name();
+ node_map_->UpdateOutput(added_node->input(0), node_->name(),
+ added_node->name());
+ node_map_->AddOutput(added_node->name(), node_->name());
+ }
};
class AvgPoolGradProcessor : public NodeProcessor {
@@ -845,7 +863,9 @@ class AvgPoolGradProcessor : public NodeProcessor {
std::vector<int> input_pos = {1};
return input_pos;
}
- Status CustomizedProcessing() override { return UpdateAttrValueOfInput(0); }
+ Status CustomizedProcessing() override {
+ return UpdateAttrValueOfInput(0, true);
+ }
};
class BiasAddGradProcessor : public NodeProcessor {
@@ -986,12 +1006,8 @@ class Conv2DBackpropInputProcessor : public Conv2DProcessor {
}
Status CustomizedProcessing() override {
- auto input_size_node = node_map_->GetNode(node_->input(0));
- if (IsConstant(*input_size_node)) {
- TF_RETURN_IF_ERROR(UpdateAttrValueOfInput(0));
- } else {
- AddDataFormatTranformToInput("DataFormatVecPermute", 0, DT_INT32);
- }
+ TF_RETURN_IF_ERROR(
+ UpdateOrTransformParamInput(0, "DataFormatVecPermute", DT_INT32));
return Status::OK();
}
};
@@ -1042,12 +1058,8 @@ class MaxPoolGradV2Processor : public MaxPoolGradProcessor {
protected:
Status CustomizedProcessing() override {
for (int i = 3; i < node_->input_size(); i++) {
- auto param_node = node_map_->GetNode(node_->input(i));
- if (IsConstant(*param_node)) {
- TF_RETURN_IF_ERROR(UpdateAttrValueOfInput(i));
- } else {
- AddDataFormatTranformToInput("DataFormatVecPermute", i, DT_INT32);
- }
+ TF_RETURN_IF_ERROR(
+ UpdateOrTransformParamInput(i, "DataFormatVecPermute", DT_INT32));
}
return Status::OK();
}
@@ -1072,12 +1084,8 @@ class MaxPoolV2Processor : public NodeProcessor {
Status CustomizedProcessing() override {
for (int i = 1; i < node_->input_size(); i++) {
- auto param_node = node_map_->GetNode(node_->input(i));
- if (IsConstant(*param_node)) {
- TF_RETURN_IF_ERROR(UpdateAttrValueOfInput(i));
- } else {
- AddDataFormatTranformToInput("DataFormatVecPermute", i, DT_INT32);
- }
+ TF_RETURN_IF_ERROR(
+ UpdateOrTransformParamInput(i, "DataFormatVecPermute", DT_INT32));
}
return Status::OK();
}
@@ -1325,17 +1333,12 @@ class ConcatProcessor : public AgnosticNodeProcessor {
}
Status CustomizedProcessing() override {
- auto dim_node = node_map_->GetNode(node_->input(axis_node_pos_));
- if (IsConstant(*dim_node)) {
- TF_RETURN_IF_ERROR(UpdateAttrValueOfInput(axis_node_pos_));
- } else {
- DataType dtype = (IsSplit(*node_) || IsSplitV(*node_))
- ? DT_INT32
- : node_->attr().at("Tidx").type();
- AddDataFormatTranformToInput("DataFormatDimMap", axis_node_pos_, dtype);
- }
+ DataType dtype = node_->attr().at("Tidx").type();
+ TF_RETURN_IF_ERROR(
+ UpdateOrTransformParamInput(axis_node_pos_, "DataFormatDimMap", dtype));
return Status::OK();
}
+
int axis_node_pos_;
};
@@ -1414,29 +1417,36 @@ class PadProcessor : public AgnosticNodeProcessor {
protected:
Status CustomizedProcessing() override {
- auto index_node = node_map_->GetNode(node_->input(1));
- if (IsConstant(*index_node)) {
- TF_RETURN_IF_ERROR(UpdateAttrValueOfInput(1));
- } else {
- DataType dtype = node_->attr().at("Tpaddings").type();
- AddDataFormatTranformToInput("DataFormatVecPermute", 1, dtype);
- }
+ DataType dtype = node_->attr().at("Tpaddings").type();
+ TF_RETURN_IF_ERROR(
+ UpdateOrTransformParamInput(1, "DataFormatVecPermute", dtype));
return Status::OK();
}
};
-class SplitProcessor : public ConcatProcessor {
+class ReverseProcessor : public AgnosticNodeProcessor {
+ public:
+ explicit ReverseProcessor(const OptimizeContext& opt_cxt)
+ : AgnosticNodeProcessor(opt_cxt) {}
+
+ protected:
+ Status CustomizedProcessing() override {
+ DataType dtype = node_->attr().at("Tidx").type();
+ TF_RETURN_IF_ERROR(
+ UpdateOrTransformParamInput(1, "DataFormatDimMap", dtype));
+ return Status::OK();
+ }
+};
+
+class SplitProcessor : public AgnosticNodeProcessor {
public:
explicit SplitProcessor(const OptimizeContext& opt_cxt)
- : ConcatProcessor(opt_cxt) {
+ : AgnosticNodeProcessor(opt_cxt) {
axis_node_pos_ = 0;
}
protected:
- std::vector<int> GetInputPos() const override {
- std::vector<int> input_pos = {1};
- return input_pos;
- }
+ std::vector<int> GetInputPos() const override { return {1}; }
std::set<int> GetOutputPos() const override {
std::set<int> output_pos{0};
@@ -1447,6 +1457,14 @@ class SplitProcessor : public ConcatProcessor {
}
return output_pos;
}
+
+ Status CustomizedProcessing() override {
+ TF_RETURN_IF_ERROR(UpdateOrTransformParamInput(
+ axis_node_pos_, "DataFormatDimMap", DT_INT32));
+ return Status::OK();
+ }
+
+ int axis_node_pos_;
};
class SplitVProcessor : public SplitProcessor {
@@ -1533,13 +1551,9 @@ class SliceProcessor : public AgnosticNodeProcessor {
Status CustomizedProcessing() override {
// Skip the first input, which is the data to be sliced.
for (int i = 1; i < node_->input_size(); i++) {
- auto index_node = node_map_->GetNode(node_->input(i));
- if (IsConstant(*index_node)) {
- TF_RETURN_IF_ERROR(UpdateAttrValueOfInput(i));
- } else {
- AddDataFormatTranformToInput("DataFormatVecPermute", i,
- node_->attr().at("Index").type());
- }
+ DataType dtype = node_->attr().at("Index").type();
+ TF_RETURN_IF_ERROR(
+ UpdateOrTransformParamInput(i, "DataFormatVecPermute", dtype));
}
return Status::OK();
}
@@ -1614,7 +1628,9 @@ class SumProcessor : public AgnosticNodeProcessor {
Status AddLayoutTransposeToOutputs() override { return Status::OK(); }
- Status CustomizedProcessing() override { return UpdateAttrValueOfInput(1); }
+ Status CustomizedProcessing() override {
+ return UpdateAttrValueOfInput(1, false);
+ }
private:
bool IsAlongDimNHW() const {
@@ -1765,6 +1781,8 @@ class DataLayoutOptimizer : GraphProcessor {
node_processor.reset(new MergeProcessor(opt_cxt));
} else if (IsPad(*node)) {
node_processor.reset(new PadProcessor(opt_cxt));
+ } else if (IsReverseV2(*node)) {
+ node_processor.reset(new ReverseProcessor(opt_cxt));
} else if (IsSlice(*node)) {
node_processor.reset(new SliceProcessor(opt_cxt));
} else if (IsShape(*node) || IsShapeN(*node)) {
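
Note on the refactored constant handling in layout_optimizer.cc above: UpdateAttrValue now takes a permute flag. With permute=true (the DataFormatVecPermute path) a length-4 vector or 4x2 matrix of NHWC-ordered values is reordered into NCHW order; with permute=false (the DataFormatDimMap path, used by Concat, Split, Sum, and the new ReverseV2 processor) each element is treated as a dimension index and remapped. A minimal Python sketch of the two rewrites, not the optimizer's actual code:

def permute_nhwc_to_nchw(vec):
    # [N, H, W, C] -> [N, C, H, W]; e.g. a strides or paddings vector of length 4.
    n, h, w, c = vec
    return [n, c, h, w]

def map_dim_nhwc_to_nchw(dim):
    # Negative axes are wrapped into [0, 4) first; then H(1)->2, W(2)->3, C(3)->1, N(0) stays 0.
    dim = dim if dim >= 0 else dim + 4
    if dim in (1, 2):
        return dim + 1
    if dim == 3:
        return 1
    return dim

assert permute_nhwc_to_nchw([1, 2, 2, 1]) == [1, 1, 2, 2]
# The ReverseV2 test below feeds constant dims [3, 1]; under NCHW they become [1, 2].
assert [map_dim_nhwc_to_nchw(d) for d in [3, 1]] == [1, 2]
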
diff --git a/tensorflow/core/grappler/optimizers/layout_optimizer_test.cc b/tensorflow/core/grappler/optimizers/layout_optimizer_test.cc
index 781f942532..161cf7ad96 100644
--- a/tensorflow/core/grappler/optimizers/layout_optimizer_test.cc
+++ b/tensorflow/core/grappler/optimizers/layout_optimizer_test.cc
@@ -559,11 +559,9 @@ TEST_F(LayoutOptimizerTest, SplitNonConstDim) {
Status status = optimizer.Optimize(virtual_cluster_.get(), item, &output);
NodeMap node_map(&output);
auto split_node = node_map.GetNode("split");
- EXPECT_EQ(split_node->input(0),
- "LayoutOptimizerVecPermuteNHWCToNCHW_split_0");
+ EXPECT_EQ(split_node->input(0), "LayoutOptimizerDimMapNHWCToNCHW_split_0");
EXPECT_EQ(split_node->input(1), "Conv2D");
- auto map_node =
- node_map.GetNode("LayoutOptimizerVecPermuteNHWCToNCHW_split_0");
+ auto map_node = node_map.GetNode("LayoutOptimizerDimMapNHWCToNCHW_split_0");
EXPECT_EQ(map_node->op(), "DataFormatDimMap");
EXPECT_EQ(map_node->input(0), "i1");
}
@@ -629,10 +627,9 @@ TEST_F(LayoutOptimizerTest, ConcatNonConst) {
auto concat_node = node_map.GetNode("concat");
EXPECT_EQ(concat_node->input(0), "split");
EXPECT_EQ(concat_node->input(1), "split:1");
- EXPECT_EQ(concat_node->input(2),
- "LayoutOptimizerVecPermuteNHWCToNCHW_concat_2");
+ EXPECT_EQ(concat_node->input(2), "LayoutOptimizerDimMapNHWCToNCHW_concat_2");
auto concat_dim =
- node_map.GetNode("LayoutOptimizerVecPermuteNHWCToNCHW_concat_2");
+ node_map.GetNode("LayoutOptimizerDimMapNHWCToNCHW_concat_2");
EXPECT_EQ(concat_dim->op(), "DataFormatDimMap");
EXPECT_EQ(concat_dim->input(0), "i");
}
diff --git a/tensorflow/core/kernels/constant_op.cc b/tensorflow/core/kernels/constant_op.cc
index 103a0e225e..c485d02e23 100644
--- a/tensorflow/core/kernels/constant_op.cc
+++ b/tensorflow/core/kernels/constant_op.cc
@@ -164,25 +164,24 @@ struct FillFunctor<CPUDevice, T> {
} // end namespace functor
-template <typename Device, typename T>
+template <typename Device, typename T, typename Index>
class FillOp : public OpKernel {
public:
explicit FillOp(OpKernelConstruction* context) : OpKernel(context) {}
void Compute(OpKernelContext* context) override {
const Tensor& Tdims = context->input(0);
- OP_REQUIRES(
- context, IsLegacyVector(Tdims.shape()),
- errors::InvalidArgument("dims must be a vector of int32, got shape ",
- Tdims.shape().DebugString()));
+ OP_REQUIRES(context, IsLegacyVector(Tdims.shape()),
+ errors::InvalidArgument("dims must be a vector, got shape ",
+ Tdims.shape().DebugString()));
const Tensor& Tvalue = context->input(1);
OP_REQUIRES(context, IsLegacyScalar(Tvalue.shape()),
errors::InvalidArgument("value must be a scalar, got shape ",
Tvalue.shape().DebugString()));
- auto dims = Tdims.flat<int32>();
+ auto dims = Tdims.flat<Index>();
TensorShape shape;
OP_REQUIRES_OK(context, TensorShapeUtils::MakeShape(
- reinterpret_cast<const int32*>(dims.data()),
+ reinterpret_cast<const Index*>(dims.data()),
dims.size(), &shape));
Tensor* out = nullptr;
OP_REQUIRES_OK(context, context->allocate_output(0, shape, &out));
@@ -214,12 +213,19 @@ struct FillFunctor<SYCLDevice, T> {
} // namespace functor
#endif // TENSORFLOW_USE_SYCL
-#define REGISTER_KERNEL(D, TYPE) \
- REGISTER_KERNEL_BUILDER(Name("Fill") \
- .Device(DEVICE_##D) \
- .TypeConstraint<TYPE>("T") \
- .HostMemory("dims"), \
- FillOp<D##Device, TYPE>);
+#define REGISTER_KERNEL(D, TYPE) \
+ REGISTER_KERNEL_BUILDER(Name("Fill") \
+ .Device(DEVICE_##D) \
+ .TypeConstraint<TYPE>("T") \
+ .TypeConstraint<int32>("index_type") \
+ .HostMemory("dims"), \
+ FillOp<D##Device, TYPE, int32>); \
+ REGISTER_KERNEL_BUILDER(Name("Fill") \
+ .Device(DEVICE_##D) \
+ .TypeConstraint<TYPE>("T") \
+ .TypeConstraint<int64>("index_type") \
+ .HostMemory("dims"), \
+ FillOp<D##Device, TYPE, int64>);
#define REGISTER_CPU_KERNEL(TYPE) REGISTER_KERNEL(CPU, TYPE)
TF_CALL_ALL_TYPES(REGISTER_CPU_KERNEL);
@@ -241,10 +247,11 @@ REGISTER_KERNEL(SYCL, int64);
REGISTER_KERNEL_BUILDER(Name("Fill")
.Device(DEVICE_SYCL)
.TypeConstraint<int32>("T")
+ .TypeConstraint<int32>("index_type")
.HostMemory("dims")
.HostMemory("value")
.HostMemory("output"),
- FillOp<CPUDevice, int32>);
+ FillOp<CPUDevice, int32, int32>);
#undef REGISTER_KERNEL_SYCL
#endif // TENSORFLOW_USE_SYCL
@@ -267,10 +274,11 @@ REGISTER_KERNEL(GPU, bool);
REGISTER_KERNEL_BUILDER(Name("Fill")
.Device(DEVICE_GPU)
.TypeConstraint<int32>("T")
+ .TypeConstraint<int32>("index_type")
.HostMemory("dims")
.HostMemory("value")
.HostMemory("output"),
- FillOp<CPUDevice, int32>);
+ FillOp<CPUDevice, int32, int32>);
#endif
#undef REGISTER_KERNEL
diff --git a/tensorflow/core/ops/array_ops.cc b/tensorflow/core/ops/array_ops.cc
index 5a31f433ce..c7d9b97461 100644
--- a/tensorflow/core/ops/array_ops.cc
+++ b/tensorflow/core/ops/array_ops.cc
@@ -1328,11 +1328,17 @@ The output will be:
// --------------------------------------------------------------------------
REGISTER_OP("Fill")
- .Input("dims: int32")
+ .Input("dims: index_type")
.Input("value: T")
.Output("output: T")
.Attr("T: type")
+ .Attr("index_type: {int32, int64} = DT_INT32")
.SetShapeFn([](InferenceContext* c) {
+ DataType index_type = DT_INT32;
+ Status s = c->GetAttr("index_type", &index_type);
+ if (!s.ok() && s.code() != error::NOT_FOUND) {
+ return s;
+ }
ShapeHandle unused;
TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 1, &unused));
TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused));
@@ -1340,7 +1346,8 @@ REGISTER_OP("Fill")
const Tensor* t = c->input_tensor(0);
if (t != nullptr) {
for (int i = 0; i < t->NumElements(); ++i) {
- if (t->vec<int32>()(i) < 0) {
+ if ((index_type == DT_INT32 && t->vec<int32>()(i) < 0) ||
+ (index_type == DT_INT64 && t->vec<int64>()(i) < 0)) {
return errors::InvalidArgument("Fill dimensions must be >= 0");
}
}
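
With the new index_type attr, Fill's dims input may now be int32 or int64. A short usage sketch of the widened signature, assuming a standard TensorFlow import (illustrative only, not part of the change):

import tensorflow as tf

# int32 dims keep working; int64 dims are now accepted too, selecting the
# int64-registered kernel through the new index_type attr.
a = tf.fill(tf.constant([2, 3], dtype=tf.int32), 7)
b = tf.fill(tf.constant([2, 3], dtype=tf.int64), 7.0)

# When dims is a constant tensor, the shape function above rejects negative
# entries for either index type at graph-construction time:
# tf.fill(tf.constant([-1, 3], dtype=tf.int64), 0.0)  # raises InvalidArgument
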
diff --git a/tensorflow/core/ops/array_ops_test.cc b/tensorflow/core/ops/array_ops_test.cc
index c8ea443613..a182fd1c47 100644
--- a/tensorflow/core/ops/array_ops_test.cc
+++ b/tensorflow/core/ops/array_ops_test.cc
@@ -253,6 +253,7 @@ TEST(ArrayOpsTest, ReverseV2_ShapeFn) {
TEST(ArrayOpsTest, Fill_ShapeFn) {
ShapeInferenceTestOp op("Fill");
+ AddNodeAttr("index_type", DT_INT32, &op.node_def);
op.input_tensors.resize(2);
INFER_OK(op, "?;?", "?");
INFER_OK(op, "[?];?", "?");
diff --git a/tensorflow/python/eager/ops_test.py b/tensorflow/python/eager/ops_test.py
index 48dcb4830c..f8c5037dcf 100644
--- a/tensorflow/python/eager/ops_test.py
+++ b/tensorflow/python/eager/ops_test.py
@@ -255,11 +255,12 @@ class OpsTest(test_util.TensorFlowTestCase):
'using.*DEVICE_PLACEMENT_SILENT'):
reshaped = array_ops.reshape(value, shape.gpu())
- def testInvalidInputDataType(self):
+ def testInt64(self):
# Fill requires the first input to be an int32 tensor.
- with self.assertRaisesRegexp(errors.InvalidArgumentError, 'int64'):
- array_ops.fill(constant_op.constant([2], dtype=dtypes.int64),
- constant_op.constant(1))
+ self.assertAllEqual(
+ [1.0, 1.0],
+ array_ops.fill(constant_op.constant([2], dtype=dtypes.int64),
+ constant_op.constant(1)))
def testOutputOnHostMemory(self):
if not context.context().num_gpus():
diff --git a/tensorflow/python/framework/constant_op.py b/tensorflow/python/framework/constant_op.py
index bf3be34d85..ac915157f5 100644
--- a/tensorflow/python/framework/constant_op.py
+++ b/tensorflow/python/framework/constant_op.py
@@ -45,6 +45,7 @@ import numpy as np
import six
from tensorflow.core.framework import attr_value_pb2
+from tensorflow.core.framework import types_pb2
from tensorflow.python.eager import context
from tensorflow.python.eager import execute
from tensorflow.python.framework import dtypes
@@ -71,7 +72,7 @@ def _eager_fill(dims, value, ctx):
attr_t = value.dtype.as_datatype_enum
dims = convert_to_eager_tensor(dims, ctx, dtypes.int32)
inputs_flat = [dims, value]
- attrs = ("T", attr_t)
+ attrs = ("T", attr_t, "index_type", types_pb2.DT_INT32)
result, = execute.execute(
b"Fill", 1, inputs=inputs_flat, attrs=attrs, ctx=ctx)
return result
diff --git a/tensorflow/python/grappler/layout_optimizer_test.py b/tensorflow/python/grappler/layout_optimizer_test.py
index 2c15e340c9..909981982d 100644
--- a/tensorflow/python/grappler/layout_optimizer_test.py
+++ b/tensorflow/python/grappler/layout_optimizer_test.py
@@ -279,7 +279,7 @@ class LayoutOptimizerTest(test.TestCase):
self.assertEqual(expected_num_transposes, num_transposes)
self.assertIn('LayoutOptimizerTransposeNHWCToNCHW-Conv2D-0', nodes)
self.assertIn('LayoutOptimizerTransposeNCHWToNHWC-split-0-0', nodes)
- self.assertIn('LayoutOptimizerVecPermuteNHWCToNCHW_split_0', nodes)
+ self.assertIn('LayoutOptimizerDimMapNHWCToNCHW_split_0', nodes)
self.assertAllClose(output_val_ref, output_val, atol=1e-3)
def testSplitVWithNonConstAxis(self):
@@ -313,7 +313,7 @@ class LayoutOptimizerTest(test.TestCase):
self.assertEqual(expected_num_transposes, num_transposes)
self.assertIn('LayoutOptimizerTransposeNHWCToNCHW-Conv2D-0', nodes)
self.assertIn('LayoutOptimizerTransposeNCHWToNHWC-SplitV-0-0', nodes)
- self.assertIn('LayoutOptimizerVecPermuteNHWCToNCHW_SplitV_2', nodes)
+ self.assertIn('LayoutOptimizerDimMapNHWCToNCHW_SplitV_2', nodes)
self.assertAllClose(output_val_ref, output_val, atol=1e-3)
def testPadWithConstPaddings(self):
@@ -350,6 +350,74 @@ class LayoutOptimizerTest(test.TestCase):
self.assertIn('LayoutOptimizer-Pad-PaddingsConst', nodes)
self.assertAllClose(output_val_ref, output_val, atol=1e-3)
+ def testReverseWithConstDims(self):
+ if test.is_gpu_available(cuda_only=True):
+ random_seed.set_random_seed(0)
+ x = random_ops.truncated_normal([1, 784], seed=0)
+ conv = _two_layer_model(x)
+ dims = constant_op.constant([3, 1], name='DimsConst')
+ reverse = array_ops.reverse(conv, dims)
+ output = array_ops.identity(reverse)
+
+ with session.Session() as sess:
+ output_val_ref = sess.run(output)
+
+ with session.Session(config=_get_config()) as sess:
+ metadata = config_pb2.RunMetadata()
+ output_val = sess.run(output, run_metadata=metadata)
+
+ nodes = []
+ num_transposes = 0
+ for node in metadata.cost_graph.node:
+ if node.name.startswith('LayoutOptimizerTranspose'):
+ num_transposes += 1
+ nodes.append(node.name)
+
+ # Four transposes were initially added in the Expand phase of
+ # LayoutOptimizer; two of them are cancelled out in the Collapse phase.
+ expected_num_transposes = 2
+ self.assertEqual(expected_num_transposes, num_transposes)
+ self.assertIn('LayoutOptimizerTransposeNHWCToNCHW-Conv2D-0', nodes)
+ self.assertIn('LayoutOptimizerTransposeNCHWToNHWC-ReverseV2-0-0', nodes)
+ self.assertIn('LayoutOptimizer-ReverseV2-DimsConst', nodes)
+ self.assertAllClose(output_val_ref, output_val, atol=1e-3)
+
+ def testReverseWithNonConstDims(self):
+ if test.is_gpu_available(cuda_only=True):
+ random_seed.set_random_seed(0)
+ x = random_ops.truncated_normal([1, 784], seed=0)
+ conv = _two_layer_model(x)
+ dims = array_ops.placeholder(dtype='int32')
+ reverse = array_ops.reverse(conv, dims)
+ output = array_ops.identity(reverse)
+
+ dims_val = [2, 3]
+ with session.Session() as sess:
+ output_val_ref = sess.run(output, feed_dict={dims: dims_val})
+
+ with session.Session(config=_get_config()) as sess:
+ metadata = config_pb2.RunMetadata()
+ output_val = sess.run(
+ output, run_metadata=metadata, feed_dict={
+ dims: dims_val
+ })
+
+ nodes = []
+ num_transposes = 0
+ for node in metadata.cost_graph.node:
+ if node.name.startswith('LayoutOptimizerTranspose'):
+ num_transposes += 1
+ nodes.append(node.name)
+
+ # Four transposes were initially added in the Expand phase of
+ # LayoutOptimizer; two of them are cancelled out in the Collapse phase.
+ expected_num_transposes = 2
+ self.assertEqual(expected_num_transposes, num_transposes)
+ self.assertIn('LayoutOptimizerTransposeNHWCToNCHW-Conv2D-0', nodes)
+ self.assertIn('LayoutOptimizerTransposeNCHWToNHWC-ReverseV2-0-0', nodes)
+ self.assertIn('LayoutOptimizerDimMapNHWCToNCHW_ReverseV2_1', nodes)
+ self.assertAllClose(output_val_ref, output_val, atol=1e-3)
+
def testTernaryOp(self):
if test.is_gpu_available(cuda_only=True):
random_seed.set_random_seed(0)
diff --git a/tensorflow/python/ops/parsing_ops.py b/tensorflow/python/ops/parsing_ops.py
index eba40c4f85..df1f00ddcd 100644
--- a/tensorflow/python/ops/parsing_ops.py
+++ b/tensorflow/python/ops/parsing_ops.py
@@ -385,7 +385,7 @@ def parse_example(serialized, features, name=None, example_names=None):
A `values[i]` comes from a position `k` in the feature of an example at batch
entry `batch`. This positional information is recorded in `indices[i]` as
`[batch, index_0, index_1, ...]` where `index_j` is the `k-th` value of
- the feature in the example at with key `SparseFeature.index_key[j].
+ the feature in the example at with key `SparseFeature.index_key[j]`.
In other words, we split the indices (except the first index indicating the
batch entry) of a `SparseTensor` by dimension into different features of the
`Example`. Due to its complexity a `VarLenFeature` should be preferred over a
@@ -1234,7 +1234,7 @@ def parse_single_example_v2(serialized, features, name=None):
A `values[i]` comes from a position `k` in the feature of an example at batch
entry `batch`. This positional information is recorded in `indices[i]` as
`[batch, index_0, index_1, ...]` where `index_j` is the `k-th` value of
- the feature in the example at with key `SparseFeature.index_key[j].
+ the feature in the example at with key `SparseFeature.index_key[j]`.
In other words, we split the indices (except the first index indicating the
batch entry) of a `SparseTensor` by dimension into different features of the
`Example`. Due to its complexity a `VarLenFeature` should be preferred over a