aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar Jared Duke <jdduke@google.com>2018-09-14 11:42:02 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-09-14 11:45:56 -0700
commit39f50af5634b8a4d2132b57bad2152308a0fd41c (patch)
tree5a5d0b0a9722067b702995dc84a1c4d8156d36a4
parentc20a7b81d79d30db9e990309ddb419bcb48120cc (diff)
Improve output parsing for unsupported ops
PiperOrigin-RevId: 213017532
-rw-r--r--tensorflow/contrib/lite/toco/import_tensorflow.cc82
-rw-r--r--tensorflow/contrib/lite/toco/import_tensorflow_test.cc52
2 files changed, 104 insertions, 30 deletions
diff --git a/tensorflow/contrib/lite/toco/import_tensorflow.cc b/tensorflow/contrib/lite/toco/import_tensorflow.cc
index efc1007925..2ccfd36b7c 100644
--- a/tensorflow/contrib/lite/toco/import_tensorflow.cc
+++ b/tensorflow/contrib/lite/toco/import_tensorflow.cc
@@ -69,6 +69,13 @@ bool HasAttr(const NodeDef& node, const string& attr_name) {
return node.attr().count(attr_name) > 0;
}
+bool HasWildcardDimension(const TensorShapeProto& shape) {
+ for (const auto& dim : shape.dim()) {
+ if (dim.size() == -1) return true;
+ }
+ return false;
+}
+
const string& GetStringAttr(const NodeDef& node, const string& attr_name) {
CHECK(HasAttr(node, attr_name));
const auto& attr = node.attr().at(attr_name);
@@ -1054,15 +1061,27 @@ tensorflow::Status ConvertUnsupportedOperator(
"_support_output_type_float_in_quantized_op";
LOG(INFO) << "Converting unsupported operation: " << node.op();
+
auto* op = new TensorFlowUnsupportedOperator;
+ op->tensorflow_op = node.op();
+ node.SerializeToString(&op->tensorflow_node_def);
+ model->operators.emplace_back(op);
+
+ // Parse inputs.
const int num_inputs = GetInputsCount(node, tf_import_flags);
for (int i = 0; i < num_inputs; ++i) {
op->inputs.push_back(node.input(i));
}
- op->outputs.push_back(node.name());
- op->tensorflow_op = node.op();
- node.SerializeToString(&op->tensorflow_node_def);
- model->operators.emplace_back(op);
+
+ // Parse outputs.
+ op->outputs.push_back(node.name()); // Implicit :0.
+ const tensorflow::OpDef* op_def = nullptr;
+ if (tensorflow::OpRegistry::Global()->LookUpOpDef(node.op(), &op_def).ok()) {
+ for (int i = 1; i < op_def->output_arg_size(); ++i) {
+ op->outputs.push_back(absl::StrCat(node.name(), ":", i));
+ }
+ }
+
// Parse if the op supports quantization
if (HasAttr(node, kAttrOutputQuantized)) {
op->quantized = GetBoolAttr(node, kAttrOutputQuantized);
@@ -1072,6 +1091,8 @@ tensorflow::Status ConvertUnsupportedOperator(
op->support_output_type_float_in_quantized_op =
GetBoolAttr(node, kAttrSupportOutputTypeFloatInQuantizedOp);
}
+
+ // Parse output type(s).
if (HasAttr(node, kAttrOutputTypes)) {
const auto& output_types = GetListAttr(node, kAttrOutputTypes);
for (int i = 0; i < output_types.type_size(); ++i) {
@@ -1080,33 +1101,40 @@ tensorflow::Status ConvertUnsupportedOperator(
} else if (HasAttr(node, "Tout")) {
const auto& output_type = GetDataTypeAttr(node, "Tout");
op->output_data_types.push_back(ConvertDataType(output_type));
- } else {
- const tensorflow::OpDef* op_def = nullptr;
- if (OpRegistry::Global()->LookUpOpDef(node.op(), &op_def).ok()) {
- for (const auto& output_arg : op_def->output_arg()) {
- if (HasAttr(node, output_arg.type_attr())) {
- op->output_data_types.push_back(
- ConvertDataType(GetDataTypeAttr(node, output_arg.type_attr())));
- } else {
- LOG(INFO) << "Op node missing output type attribute: " << node.name();
- op->output_data_types.clear();
- break;
- }
+ } else if (op_def != nullptr) {
+ for (const auto& output_arg : op_def->output_arg()) {
+ if (HasAttr(node, output_arg.type_attr())) {
+ op->output_data_types.push_back(
+ ConvertDataType(GetDataTypeAttr(node, output_arg.type_attr())));
+ } else {
+ LOG(INFO) << "Op node missing output type attribute: " << node.name();
+ op->output_data_types.clear();
+ break;
}
}
- if (op->output_data_types.empty()) {
- // TODO(b/113613439): Figure out how to propagate types for custom ops
- // that have no OpDef.
- LOG(INFO) << "Unable to determine output type for op: " << node.op();
- }
+ } else {
+ // TODO(b/113613439): Figure out how to propagate types for custom ops
+ // that have no OpDef.
+ LOG(INFO) << "Unable to determine output type for op: " << node.op();
}
+
+ // Parse output shape(s).
if (HasAttr(node, kAttrOutputShapes)) {
const auto& output_shapes = GetListAttr(node, kAttrOutputShapes);
Shape output_shape;
for (int i = 0; i < output_shapes.shape_size(); ++i) {
+ const auto& shape = output_shapes.shape(i);
+ // TOCO doesn't yet properly handle shapes with wildcard dimensions.
+ // TODO(b/113613439): Handle shape inference for unsupported ops that have
+ // shapes with wildcard dimensions.
+ if (HasWildcardDimension(shape)) {
+ LOG(INFO) << "Skipping wildcard output shape(s) for node: "
+ << node.name();
+ op->output_shapes.clear();
+ break;
+ }
const auto status =
- ImportShape(output_shapes.shape(i).dim(), /*input_flat_size=*/nullptr,
- &output_shape);
+ ImportShape(shape.dim(), /*input_flat_size=*/nullptr, &output_shape);
if (!status.ok()) {
return status;
}
@@ -1159,15 +1187,9 @@ tensorflow::Status ConvertPlaceholderOperator(
if (node.attr().count("shape")) {
const auto& shape = GetShapeAttr(node, "shape");
auto num_dims = shape.dim_size();
- bool has_wildcard = false;
- for (std::size_t i = 0; i < num_dims; i++) {
- if (shape.dim(i).size() == -1) {
- has_wildcard = true;
- }
- }
// TODO(b/62716978): This logic needs to be revisted. During dims
// refactoring it is an interim fix.
- if (num_dims > 0 && !has_wildcard) {
+ if (num_dims > 0 && !HasWildcardDimension(shape)) {
auto& dst_array_dims = *array.mutable_shape()->mutable_dims();
dst_array_dims.resize(num_dims);
for (std::size_t i = 0; i < num_dims; i++) {
diff --git a/tensorflow/contrib/lite/toco/import_tensorflow_test.cc b/tensorflow/contrib/lite/toco/import_tensorflow_test.cc
index da248826a7..8a236d4444 100644
--- a/tensorflow/contrib/lite/toco/import_tensorflow_test.cc
+++ b/tensorflow/contrib/lite/toco/import_tensorflow_test.cc
@@ -60,6 +60,28 @@ Status ImportNode(const NodeDef& node) {
return ImportNode(node, &model);
}
+NodeDef BuildNode(
+ const std::string& op,
+ const std::vector<std::initializer_list<int>>& output_shapes) {
+ NodeDef node;
+ node.set_op(op);
+ node.set_name("Node1");
+ node.add_input();
+ node.set_input(0, "Node0");
+
+ AttrValue::ListValue* shapes =
+ (*node.mutable_attr())["_output_shapes"].mutable_list();
+ for (const auto& output_shape : output_shapes) {
+ tensorflow::TensorShapeProto* shape = shapes->add_shape();
+ for (int64_t output_shape_dim : output_shape) {
+ auto shape_dim = shape->add_dim();
+ shape_dim->set_size(output_shape_dim);
+ }
+ }
+
+ return node;
+}
+
class ShapeImportTest : public ::testing::TestWithParam<tensorflow::DataType> {
protected:
ShapeImportTest() {}
@@ -232,5 +254,35 @@ TEST(ImportTest, FailedTypeInference) {
ASSERT_TRUE(op->output_data_types.empty());
}
+TEST(ImportTest, UnsupportedOpWithOutputShapes) {
+ // Create an unsupported op with output shapes.
+ Model model;
+ EXPECT_TRUE(ImportNode(BuildNode("Atan", {{1, 2}, {2, 3}}), &model).ok());
+ ASSERT_THAT(model.operators.size(), ::testing::Ge(1));
+ ASSERT_EQ(model.operators[0]->type, OperatorType::kUnsupported);
+ const TensorFlowUnsupportedOperator* op =
+ static_cast<const TensorFlowUnsupportedOperator*>(
+ model.operators[0].get());
+
+ // The output shapes should be imported.
+ ASSERT_EQ(op->output_shapes.size(), 2);
+ ASSERT_THAT(op->output_shapes[0].dims(), ::testing::ElementsAre(1, 2));
+ ASSERT_THAT(op->output_shapes[1].dims(), ::testing::ElementsAre(2, 3));
+}
+
+TEST(ImportTest, UnsupportedOpWithWildcardOutputShapes) {
+ // Create an unsupported op with wildcard output shapes.
+ Model model;
+ EXPECT_TRUE(ImportNode(BuildNode("Atan", {{-1, 2}}), &model).ok());
+ ASSERT_THAT(model.operators.size(), ::testing::Ge(1));
+ ASSERT_EQ(model.operators[0]->type, OperatorType::kUnsupported);
+ const TensorFlowUnsupportedOperator* op =
+ static_cast<const TensorFlowUnsupportedOperator*>(
+ model.operators[0].get());
+
+ // Wildcard shapes aren't yet supported.
+ ASSERT_TRUE(op->output_shapes.empty());
+}
+
} // namespace
} // namespace toco