From b055d78b0edbf117ec5f7f2662d3bb2781ae02b3 Mon Sep 17 00:00:00 2001 From: Jared Duke Date: Mon, 8 Oct 2018 15:09:57 -0700 Subject: Fix issue with type inference for ops with fixed output types Use the ArgDef::type field when available for propagating the output types from a given unsupported operator. PiperOrigin-RevId: 216257741 --- tensorflow/contrib/lite/toco/import_tensorflow.cc | 7 +++++-- tensorflow/contrib/lite/toco/import_tensorflow_test.cc | 15 +++++++++++++++ 2 files changed, 20 insertions(+), 2 deletions(-) diff --git a/tensorflow/contrib/lite/toco/import_tensorflow.cc b/tensorflow/contrib/lite/toco/import_tensorflow.cc index 133ef79a34..32f22e1ea0 100644 --- a/tensorflow/contrib/lite/toco/import_tensorflow.cc +++ b/tensorflow/contrib/lite/toco/import_tensorflow.cc @@ -1151,11 +1151,14 @@ tensorflow::Status ConvertUnsupportedOperator( op->output_data_types.push_back(ConvertDataType(output_type)); } else if (op_def != nullptr) { for (const auto& output_arg : op_def->output_arg()) { - if (HasAttr(node, output_arg.type_attr())) { + if (output_arg.type() != tensorflow::DT_INVALID) { + op->output_data_types.push_back(ConvertDataType(output_arg.type())); + } else if (HasAttr(node, output_arg.type_attr())) { op->output_data_types.push_back( ConvertDataType(GetDataTypeAttr(node, output_arg.type_attr()))); } else { - LOG(INFO) << "Op node missing output type attribute: " << node.name(); + LOG(WARNING) << "Op node missing output type attribute: " + << node.name(); op->output_data_types.clear(); break; } diff --git a/tensorflow/contrib/lite/toco/import_tensorflow_test.cc b/tensorflow/contrib/lite/toco/import_tensorflow_test.cc index 8a236d4444..cd9a144b52 100644 --- a/tensorflow/contrib/lite/toco/import_tensorflow_test.cc +++ b/tensorflow/contrib/lite/toco/import_tensorflow_test.cc @@ -235,6 +235,21 @@ TEST_P(TypeImportTest, BasicTypeInference) { INSTANTIATE_TEST_CASE_P(BasicTypeInference, TypeImportTest, ::testing::ValuesIn(UnaryTestTypes())); +TEST(ImportTest, TypeInferenceWithFixedOutputType) { + // Create an op that has a fixed output type (bool). + Model model; + EXPECT_TRUE(ImportNode(BuildNode("IsFinite", {{1, 2}, {2, 3}}), &model).ok()); + ASSERT_THAT(model.operators.size(), ::testing::Ge(1)); + ASSERT_EQ(model.operators[0]->type, OperatorType::kUnsupported); + const TensorFlowUnsupportedOperator* op = + static_cast( + model.operators[0].get()); + + // The static output type should be indicated in the imported op. + ASSERT_THAT(op->output_data_types, + ::testing::ElementsAre(ArrayDataType::kBool)); +} + TEST(ImportTest, FailedTypeInference) { // Create a unary op with no Type ("T") annotation. NodeDef node; -- cgit v1.2.3