aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2018-07-10 13:26:21 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-07-10 13:32:56 -0700
commit89dab281fdf9156af098937cadb058d29fa05227 (patch)
tree75fa580bd2fffdbd7c2c50cf7d87892465d26d33
parentdd1ce0fd8f8c97728475428f562825480e9d4194 (diff)
Adding int16 as a supported Select type.
PiperOrigin-RevId: 204001217
-rw-r--r--tensorflow/contrib/lite/kernels/select.cc3
-rw-r--r--tensorflow/contrib/lite/kernels/select_test.cc13
2 files changed, 16 insertions, 0 deletions
diff --git a/tensorflow/contrib/lite/kernels/select.cc b/tensorflow/contrib/lite/kernels/select.cc
index 9b6cee3cb5..3cdb5db209 100644
--- a/tensorflow/contrib/lite/kernels/select.cc
+++ b/tensorflow/contrib/lite/kernels/select.cc
@@ -89,6 +89,9 @@ TfLiteStatus SelectEval(TfLiteContext* context, TfLiteNode* node) {
case kTfLiteUInt8: \
TF_LITE_SELECT(uint8_t, op); \
break; \
+ case kTfLiteInt16: \
+ TF_LITE_SELECT(int16_t, op); \
+ break; \
case kTfLiteInt32: \
TF_LITE_SELECT(int32_t, op); \
break; \
diff --git a/tensorflow/contrib/lite/kernels/select_test.cc b/tensorflow/contrib/lite/kernels/select_test.cc
index 4664b9acb4..5b2e61cd29 100644
--- a/tensorflow/contrib/lite/kernels/select_test.cc
+++ b/tensorflow/contrib/lite/kernels/select_test.cc
@@ -96,6 +96,19 @@ TEST(SelectOpTest, SelectUInt8) {
EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({1, 1, 1, 4}));
}
+TEST(SelectOpTest, SelectInt16) {
+ SelectOpModel model({1, 1, 1, 4}, {1, 1, 1, 4}, {1, 1, 1, 4},
+ TensorType_INT16);
+
+ model.PopulateTensor<bool>(model.input1(), {false, true, false, false});
+ model.PopulateTensor<int16_t>(model.input2(), {1, 2, 3, 4});
+ model.PopulateTensor<int16_t>(model.input3(), {5, 6, 7, 8});
+ model.Invoke();
+
+ EXPECT_THAT(model.GetOutput<int16_t>(), ElementsAreArray({5, 2, 7, 8}));
+ EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({1, 1, 1, 4}));
+}
+
TEST(SelectOpTest, SelectInt32) {
SelectOpModel model({1, 1, 1, 4}, {1, 1, 1, 4}, {1, 1, 1, 4},
TensorType_INT32);