aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar Jared Duke <jdduke@google.com>2018-10-08 15:09:57 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-10-08 15:13:50 -0700
commitb055d78b0edbf117ec5f7f2662d3bb2781ae02b3 (patch)
tree0cf6e49ab301aa03f8231a72cb528452ef2513b1
parent5da3cebe00111aa43e34b5a3fc12d1a97b838ba7 (diff)
Fix issue with type inference for ops with fixed output types
Use the ArgDef::type field when available for propagating the output types from a given unsupported operator. PiperOrigin-RevId: 216257741
-rw-r--r--tensorflow/contrib/lite/toco/import_tensorflow.cc7
-rw-r--r--tensorflow/contrib/lite/toco/import_tensorflow_test.cc15
2 files changed, 20 insertions, 2 deletions
diff --git a/tensorflow/contrib/lite/toco/import_tensorflow.cc b/tensorflow/contrib/lite/toco/import_tensorflow.cc
index 133ef79a34..32f22e1ea0 100644
--- a/tensorflow/contrib/lite/toco/import_tensorflow.cc
+++ b/tensorflow/contrib/lite/toco/import_tensorflow.cc
@@ -1151,11 +1151,14 @@ tensorflow::Status ConvertUnsupportedOperator(
op->output_data_types.push_back(ConvertDataType(output_type));
} else if (op_def != nullptr) {
for (const auto& output_arg : op_def->output_arg()) {
- if (HasAttr(node, output_arg.type_attr())) {
+ if (output_arg.type() != tensorflow::DT_INVALID) {
+ op->output_data_types.push_back(ConvertDataType(output_arg.type()));
+ } else if (HasAttr(node, output_arg.type_attr())) {
op->output_data_types.push_back(
ConvertDataType(GetDataTypeAttr(node, output_arg.type_attr())));
} else {
- LOG(INFO) << "Op node missing output type attribute: " << node.name();
+ LOG(WARNING) << "Op node missing output type attribute: "
+ << node.name();
op->output_data_types.clear();
break;
}
diff --git a/tensorflow/contrib/lite/toco/import_tensorflow_test.cc b/tensorflow/contrib/lite/toco/import_tensorflow_test.cc
index 8a236d4444..cd9a144b52 100644
--- a/tensorflow/contrib/lite/toco/import_tensorflow_test.cc
+++ b/tensorflow/contrib/lite/toco/import_tensorflow_test.cc
@@ -235,6 +235,21 @@ TEST_P(TypeImportTest, BasicTypeInference) {
INSTANTIATE_TEST_CASE_P(BasicTypeInference, TypeImportTest,
::testing::ValuesIn(UnaryTestTypes()));
+TEST(ImportTest, TypeInferenceWithFixedOutputType) {
+ // Create an op that has a fixed output type (bool).
+ Model model;
+ EXPECT_TRUE(ImportNode(BuildNode("IsFinite", {{1, 2}, {2, 3}}), &model).ok());
+ ASSERT_THAT(model.operators.size(), ::testing::Ge(1));
+ ASSERT_EQ(model.operators[0]->type, OperatorType::kUnsupported);
+ const TensorFlowUnsupportedOperator* op =
+ static_cast<const TensorFlowUnsupportedOperator*>(
+ model.operators[0].get());
+
+ // The static output type should be indicated in the imported op.
+ ASSERT_THAT(op->output_data_types,
+ ::testing::ElementsAre(ArrayDataType::kBool));
+}
+
TEST(ImportTest, FailedTypeInference) {
// Create a unary op with no Type ("T") annotation.
NodeDef node;