diff options
author | Jared Duke <jdduke@google.com> | 2018-10-08 15:09:57 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-10-08 15:13:50 -0700 |
commit | b055d78b0edbf117ec5f7f2662d3bb2781ae02b3 (patch) | |
tree | 0cf6e49ab301aa03f8231a72cb528452ef2513b1 /tensorflow/contrib/lite/toco/import_tensorflow_test.cc | |
parent | 5da3cebe00111aa43e34b5a3fc12d1a97b838ba7 (diff) |
Fix issue with type inference for ops with fixed output types
Use the ArgDef::type field when available for propagating
the output types from a given unsupported operator.
PiperOrigin-RevId: 216257741
Diffstat (limited to 'tensorflow/contrib/lite/toco/import_tensorflow_test.cc')
-rw-r--r-- | tensorflow/contrib/lite/toco/import_tensorflow_test.cc | 15 |
1 files changed, 15 insertions, 0 deletions
diff --git a/tensorflow/contrib/lite/toco/import_tensorflow_test.cc b/tensorflow/contrib/lite/toco/import_tensorflow_test.cc index 8a236d4444..cd9a144b52 100644 --- a/tensorflow/contrib/lite/toco/import_tensorflow_test.cc +++ b/tensorflow/contrib/lite/toco/import_tensorflow_test.cc @@ -235,6 +235,21 @@ TEST_P(TypeImportTest, BasicTypeInference) { INSTANTIATE_TEST_CASE_P(BasicTypeInference, TypeImportTest, ::testing::ValuesIn(UnaryTestTypes())); +TEST(ImportTest, TypeInferenceWithFixedOutputType) { + // Create an op that has a fixed output type (bool). + Model model; + EXPECT_TRUE(ImportNode(BuildNode("IsFinite", {{1, 2}, {2, 3}}), &model).ok()); + ASSERT_THAT(model.operators.size(), ::testing::Ge(1)); + ASSERT_EQ(model.operators[0]->type, OperatorType::kUnsupported); + const TensorFlowUnsupportedOperator* op = + static_cast<const TensorFlowUnsupportedOperator*>( + model.operators[0].get()); + + // The static output type should be indicated in the imported op. + ASSERT_THAT(op->output_data_types, + ::testing::ElementsAre(ArrayDataType::kBool)); +} + TEST(ImportTest, FailedTypeInference) { // Create a unary op with no Type ("T") annotation. NodeDef node; |