aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/lite/delegates/eager/delegate_test.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/contrib/lite/delegates/eager/delegate_test.cc')
-rw-r--r--tensorflow/contrib/lite/delegates/eager/delegate_test.cc20
1 files changed, 20 insertions, 0 deletions
diff --git a/tensorflow/contrib/lite/delegates/eager/delegate_test.cc b/tensorflow/contrib/lite/delegates/eager/delegate_test.cc
index eb47f46c0b..984f8bbc98 100644
--- a/tensorflow/contrib/lite/delegates/eager/delegate_test.cc
+++ b/tensorflow/contrib/lite/delegates/eager/delegate_test.cc
@@ -72,6 +72,26 @@ TEST_F(DelegateTest, FullGraph) {
ASSERT_THAT(GetShape(8), ElementsAre(2, 1));
ASSERT_THAT(GetValues(8), ElementsAre(14.52f, 38.72f));
+ ASSERT_EQ(GetType(8), kTfLiteFloat32);
+}
+
+TEST_F(DelegateTest, NonFloatTypeInference) {
+ AddTensors(3, {0, 1}, {2}, kTfLiteInt32, {2});
+
+ AddTfOp(testing::kAdd, {0, 1}, {2});
+
+ ConfigureDelegate();
+
+ SetShape(0, {2, 2});
+ SetTypedValues<int>(0, {1, 2, 3, 4});
+ SetShape(1, {2, 2});
+ SetTypedValues<int>(1, {4, 3, 2, 1});
+
+ ASSERT_TRUE(Invoke());
+
+ ASSERT_THAT(GetShape(2), ElementsAre(2, 2));
+ ASSERT_THAT(GetTypedValues<int>(2), ElementsAre(5, 5, 5, 5));
+ ASSERT_EQ(GetType(2), kTfLiteInt32);
}
TEST_F(DelegateTest, MixedGraph) {