diff options
author:    Jared Duke <jdduke@google.com>  2018-09-14 11:42:02 -0700
committer: TensorFlower Gardener <gardener@tensorflow.org>  2018-09-14 11:45:56 -0700
commit:    39f50af5634b8a4d2132b57bad2152308a0fd41c (patch)
tree:      5a5d0b0a9722067b702995dc84a1c4d8156d36a4 /tensorflow/contrib
parent:    c20a7b81d79d30db9e990309ddb419bcb48120cc (diff)
Improve output parsing for unsupported ops
PiperOrigin-RevId: 213017532
Diffstat (limited to 'tensorflow/contrib')
-rw-r--r--  tensorflow/contrib/lite/toco/import_tensorflow.cc       82
-rw-r--r--  tensorflow/contrib/lite/toco/import_tensorflow_test.cc  52
2 files changed, 104 insertions, 30 deletions
diff --git a/tensorflow/contrib/lite/toco/import_tensorflow.cc b/tensorflow/contrib/lite/toco/import_tensorflow.cc index efc1007925..2ccfd36b7c 100644 --- a/tensorflow/contrib/lite/toco/import_tensorflow.cc +++ b/tensorflow/contrib/lite/toco/import_tensorflow.cc @@ -69,6 +69,13 @@ bool HasAttr(const NodeDef& node, const string& attr_name) { return node.attr().count(attr_name) > 0; } +bool HasWildcardDimension(const TensorShapeProto& shape) { + for (const auto& dim : shape.dim()) { + if (dim.size() == -1) return true; + } + return false; +} + const string& GetStringAttr(const NodeDef& node, const string& attr_name) { CHECK(HasAttr(node, attr_name)); const auto& attr = node.attr().at(attr_name); @@ -1054,15 +1061,27 @@ tensorflow::Status ConvertUnsupportedOperator( "_support_output_type_float_in_quantized_op"; LOG(INFO) << "Converting unsupported operation: " << node.op(); + auto* op = new TensorFlowUnsupportedOperator; + op->tensorflow_op = node.op(); + node.SerializeToString(&op->tensorflow_node_def); + model->operators.emplace_back(op); + + // Parse inputs. const int num_inputs = GetInputsCount(node, tf_import_flags); for (int i = 0; i < num_inputs; ++i) { op->inputs.push_back(node.input(i)); } - op->outputs.push_back(node.name()); - op->tensorflow_op = node.op(); - node.SerializeToString(&op->tensorflow_node_def); - model->operators.emplace_back(op); + + // Parse outputs. + op->outputs.push_back(node.name()); // Implicit :0. 
+ const tensorflow::OpDef* op_def = nullptr; + if (tensorflow::OpRegistry::Global()->LookUpOpDef(node.op(), &op_def).ok()) { + for (int i = 1; i < op_def->output_arg_size(); ++i) { + op->outputs.push_back(absl::StrCat(node.name(), ":", i)); + } + } + // Parse if the op supports quantization if (HasAttr(node, kAttrOutputQuantized)) { op->quantized = GetBoolAttr(node, kAttrOutputQuantized); @@ -1072,6 +1091,8 @@ tensorflow::Status ConvertUnsupportedOperator( op->support_output_type_float_in_quantized_op = GetBoolAttr(node, kAttrSupportOutputTypeFloatInQuantizedOp); } + + // Parse output type(s). if (HasAttr(node, kAttrOutputTypes)) { const auto& output_types = GetListAttr(node, kAttrOutputTypes); for (int i = 0; i < output_types.type_size(); ++i) { @@ -1080,33 +1101,40 @@ tensorflow::Status ConvertUnsupportedOperator( } else if (HasAttr(node, "Tout")) { const auto& output_type = GetDataTypeAttr(node, "Tout"); op->output_data_types.push_back(ConvertDataType(output_type)); - } else { - const tensorflow::OpDef* op_def = nullptr; - if (OpRegistry::Global()->LookUpOpDef(node.op(), &op_def).ok()) { - for (const auto& output_arg : op_def->output_arg()) { - if (HasAttr(node, output_arg.type_attr())) { - op->output_data_types.push_back( - ConvertDataType(GetDataTypeAttr(node, output_arg.type_attr()))); - } else { - LOG(INFO) << "Op node missing output type attribute: " << node.name(); - op->output_data_types.clear(); - break; - } + } else if (op_def != nullptr) { + for (const auto& output_arg : op_def->output_arg()) { + if (HasAttr(node, output_arg.type_attr())) { + op->output_data_types.push_back( + ConvertDataType(GetDataTypeAttr(node, output_arg.type_attr()))); + } else { + LOG(INFO) << "Op node missing output type attribute: " << node.name(); + op->output_data_types.clear(); + break; } } - if (op->output_data_types.empty()) { - // TODO(b/113613439): Figure out how to propagate types for custom ops - // that have no OpDef. 
- LOG(INFO) << "Unable to determine output type for op: " << node.op(); - } + } else { + // TODO(b/113613439): Figure out how to propagate types for custom ops + // that have no OpDef. + LOG(INFO) << "Unable to determine output type for op: " << node.op(); } + + // Parse output shape(s). if (HasAttr(node, kAttrOutputShapes)) { const auto& output_shapes = GetListAttr(node, kAttrOutputShapes); Shape output_shape; for (int i = 0; i < output_shapes.shape_size(); ++i) { + const auto& shape = output_shapes.shape(i); + // TOCO doesn't yet properly handle shapes with wildcard dimensions. + // TODO(b/113613439): Handle shape inference for unsupported ops that have + // shapes with wildcard dimensions. + if (HasWildcardDimension(shape)) { + LOG(INFO) << "Skipping wildcard output shape(s) for node: " + << node.name(); + op->output_shapes.clear(); + break; + } const auto status = - ImportShape(output_shapes.shape(i).dim(), /*input_flat_size=*/nullptr, - &output_shape); + ImportShape(shape.dim(), /*input_flat_size=*/nullptr, &output_shape); if (!status.ok()) { return status; } @@ -1159,15 +1187,9 @@ tensorflow::Status ConvertPlaceholderOperator( if (node.attr().count("shape")) { const auto& shape = GetShapeAttr(node, "shape"); auto num_dims = shape.dim_size(); - bool has_wildcard = false; - for (std::size_t i = 0; i < num_dims; i++) { - if (shape.dim(i).size() == -1) { - has_wildcard = true; - } - } // TODO(b/62716978): This logic needs to be revisted. During dims // refactoring it is an interim fix. 
- if (num_dims > 0 && !has_wildcard) { + if (num_dims > 0 && !HasWildcardDimension(shape)) { auto& dst_array_dims = *array.mutable_shape()->mutable_dims(); dst_array_dims.resize(num_dims); for (std::size_t i = 0; i < num_dims; i++) { diff --git a/tensorflow/contrib/lite/toco/import_tensorflow_test.cc b/tensorflow/contrib/lite/toco/import_tensorflow_test.cc index da248826a7..8a236d4444 100644 --- a/tensorflow/contrib/lite/toco/import_tensorflow_test.cc +++ b/tensorflow/contrib/lite/toco/import_tensorflow_test.cc @@ -60,6 +60,28 @@ Status ImportNode(const NodeDef& node) { return ImportNode(node, &model); } +NodeDef BuildNode( + const std::string& op, + const std::vector<std::initializer_list<int>>& output_shapes) { + NodeDef node; + node.set_op(op); + node.set_name("Node1"); + node.add_input(); + node.set_input(0, "Node0"); + + AttrValue::ListValue* shapes = + (*node.mutable_attr())["_output_shapes"].mutable_list(); + for (const auto& output_shape : output_shapes) { + tensorflow::TensorShapeProto* shape = shapes->add_shape(); + for (int64_t output_shape_dim : output_shape) { + auto shape_dim = shape->add_dim(); + shape_dim->set_size(output_shape_dim); + } + } + + return node; +} + class ShapeImportTest : public ::testing::TestWithParam<tensorflow::DataType> { protected: ShapeImportTest() {} @@ -232,5 +254,35 @@ TEST(ImportTest, FailedTypeInference) { ASSERT_TRUE(op->output_data_types.empty()); } +TEST(ImportTest, UnsupportedOpWithOutputShapes) { + // Create an unsupported op with output shapes. + Model model; + EXPECT_TRUE(ImportNode(BuildNode("Atan", {{1, 2}, {2, 3}}), &model).ok()); + ASSERT_THAT(model.operators.size(), ::testing::Ge(1)); + ASSERT_EQ(model.operators[0]->type, OperatorType::kUnsupported); + const TensorFlowUnsupportedOperator* op = + static_cast<const TensorFlowUnsupportedOperator*>( + model.operators[0].get()); + + // The output shapes should be imported. 
+ ASSERT_EQ(op->output_shapes.size(), 2); + ASSERT_THAT(op->output_shapes[0].dims(), ::testing::ElementsAre(1, 2)); + ASSERT_THAT(op->output_shapes[1].dims(), ::testing::ElementsAre(2, 3)); +} + +TEST(ImportTest, UnsupportedOpWithWildcardOutputShapes) { + // Create an unsupported op with wildcard output shapes. + Model model; + EXPECT_TRUE(ImportNode(BuildNode("Atan", {{-1, 2}}), &model).ok()); + ASSERT_THAT(model.operators.size(), ::testing::Ge(1)); + ASSERT_EQ(model.operators[0]->type, OperatorType::kUnsupported); + const TensorFlowUnsupportedOperator* op = + static_cast<const TensorFlowUnsupportedOperator*>( + model.operators[0].get()); + + // Wildcard shapes aren't yet supported. + ASSERT_TRUE(op->output_shapes.empty()); +} + } // namespace } // namespace toco |