aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/lite/kernels/cast_test.cc
diff options
context:
space:
mode:
authorGravatar Yu-Cheng Ling <ycling@google.com>2018-05-17 14:58:04 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-05-17 15:00:50 -0700
commitfacd8f50733a398cc0ee08dfe76ad6b4f9e61817 (patch)
treed53c825abb8615eb8020e3222477750b770732ab /tensorflow/contrib/lite/kernels/cast_test.cc
parent695c97c3ddf73245ceeb9884eb4bc7d86f44532e (diff)
Support Bool in Cast (TFLite)
PiperOrigin-RevId: 197056978
Diffstat (limited to 'tensorflow/contrib/lite/kernels/cast_test.cc')
-rw-r--r--tensorflow/contrib/lite/kernels/cast_test.cc16
1 files changed, 16 insertions, 0 deletions
diff --git a/tensorflow/contrib/lite/kernels/cast_test.cc b/tensorflow/contrib/lite/kernels/cast_test.cc
index 4e56482a37..53e2000737 100644
--- a/tensorflow/contrib/lite/kernels/cast_test.cc
+++ b/tensorflow/contrib/lite/kernels/cast_test.cc
@@ -57,6 +57,22 @@ TEST(CastOpModel, CastFloatToInt) {
ElementsAreArray({100, 20, 3, 0, 0, 1}));
}
+TEST(CastOpModel, CastFloatToBool) {
+ CastOpModel m({TensorType_FLOAT32, {3, 2}}, {TensorType_BOOL, {3, 2}});
+ m.PopulateTensor<float>(m.input(), {100.f, -1.0f, 0.f, 0.4f, 0.999f, 1.1f});
+ m.Invoke();
+ EXPECT_THAT(m.ExtractVector<bool>(m.output()),
+ ElementsAreArray({true, true, false, true, true, true}));
+}
+
+TEST(CastOpModel, CastBoolToFloat) {
+ CastOpModel m({TensorType_BOOL, {3, 2}}, {TensorType_FLOAT32, {3, 2}});
+ m.PopulateTensor<bool>(m.input(), {true, true, false, true, false, true});
+ m.Invoke();
+ EXPECT_THAT(m.ExtractVector<float>(m.output()),
+ ElementsAreArray({1.f, 1.0f, 0.f, 1.0f, 0.0f, 1.0f}));
+}
+
} // namespace
} // namespace tflite
int main(int argc, char** argv) {