aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/lite/kernels/cast_test.cc
diff options
context:
space:
mode:
authorGravatar RJ Ryan <rjryan@google.com>2018-06-27 18:11:03 -0700
committerGravatar Gunhan Gulsoy <gunan@google.com>2018-06-28 21:37:43 -0700
commit0bf43348c6269cf46b3e16f93831fa05f226b896 (patch)
tree11b030ed738cb9b066114087361ef202eb6a7b90 /tensorflow/contrib/lite/kernels/cast_test.cc
parent54edcf739f928d4314e410c050abfc79f27bad38 (diff)
Add complex64 support to tf.lite runtime.
PiperOrigin-RevId: 202403235
Diffstat (limited to 'tensorflow/contrib/lite/kernels/cast_test.cc')
-rw-r--r--tensorflow/contrib/lite/kernels/cast_test.cc67
1 files changed, 67 insertions, 0 deletions
diff --git a/tensorflow/contrib/lite/kernels/cast_test.cc b/tensorflow/contrib/lite/kernels/cast_test.cc
index 53e2000737..954f998206 100644
--- a/tensorflow/contrib/lite/kernels/cast_test.cc
+++ b/tensorflow/contrib/lite/kernels/cast_test.cc
@@ -12,6 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
+#include <complex>
+
#include <gtest/gtest.h>
#include "tensorflow/contrib/lite/interpreter.h"
#include "tensorflow/contrib/lite/kernels/register.h"
@@ -73,6 +75,71 @@ TEST(CastOpModel, CastBoolToFloat) {
ElementsAreArray({1.f, 1.0f, 0.f, 1.0f, 0.0f, 1.0f}));
}
+TEST(CastOpModel, CastComplex64ToFloat) {
+ CastOpModel m({TensorType_COMPLEX64, {2, 3}}, {TensorType_FLOAT32, {2, 3}});
+ m.PopulateTensor<std::complex<float>>(
+ m.input(),
+ {std::complex<float>(1.0f, 11.0f), std::complex<float>(2.0f, 12.0f),
+ std::complex<float>(3.0f, 13.0f), std::complex<float>(4.0f, 14.0f),
+ std::complex<float>(5.0f, 15.0f), std::complex<float>(6.0f, 16.0f)});
+ m.Invoke();
+ EXPECT_THAT(m.ExtractVector<float>(m.output()),
+ ElementsAreArray({1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f}));
+}
+
+TEST(CastOpModel, CastFloatToComplex64) {
+ CastOpModel m({TensorType_FLOAT32, {2, 3}}, {TensorType_COMPLEX64, {2, 3}});
+ m.PopulateTensor<float>(m.input(), {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f});
+ m.Invoke();
+ EXPECT_THAT(
+ m.ExtractVector<std::complex<float>>(m.output()),
+ ElementsAreArray(
+ {std::complex<float>(1.0f, 0.0f), std::complex<float>(2.0f, 0.0f),
+ std::complex<float>(3.0f, 0.0f), std::complex<float>(4.0f, 0.0f),
+ std::complex<float>(5.0f, 0.0f), std::complex<float>(6.0f, 0.0f)}));
+}
+
+TEST(CastOpModel, CastComplex64ToInt) {
+ CastOpModel m({TensorType_COMPLEX64, {2, 3}}, {TensorType_INT32, {2, 3}});
+ m.PopulateTensor<std::complex<float>>(
+ m.input(),
+ {std::complex<float>(1.0f, 11.0f), std::complex<float>(2.0f, 12.0f),
+ std::complex<float>(3.0f, 13.0f), std::complex<float>(4.0f, 14.0f),
+ std::complex<float>(5.0f, 15.0f), std::complex<float>(6.0f, 16.0f)});
+ m.Invoke();
+ EXPECT_THAT(m.ExtractVector<int>(m.output()),
+ ElementsAreArray({1, 2, 3, 4, 5, 6}));
+}
+
+TEST(CastOpModel, CastIntToComplex64) {
+ CastOpModel m({TensorType_INT32, {2, 3}}, {TensorType_COMPLEX64, {2, 3}});
+ m.PopulateTensor<int>(m.input(), {1, 2, 3, 4, 5, 6});
+ m.Invoke();
+ EXPECT_THAT(
+ m.ExtractVector<std::complex<float>>(m.output()),
+ ElementsAreArray(
+ {std::complex<float>(1.0f, 0.0f), std::complex<float>(2.0f, 0.0f),
+ std::complex<float>(3.0f, 0.0f), std::complex<float>(4.0f, 0.0f),
+ std::complex<float>(5.0f, 0.0f), std::complex<float>(6.0f, 0.0f)}));
+}
+
+TEST(CastOpModel, CastComplex64ToComplex64) {
+ CastOpModel m({TensorType_COMPLEX64, {2, 3}}, {TensorType_COMPLEX64, {2, 3}});
+ m.PopulateTensor<std::complex<float>>(
+ m.input(),
+ {std::complex<float>(1.0f, 11.0f), std::complex<float>(2.0f, 12.0f),
+ std::complex<float>(3.0f, 13.0f), std::complex<float>(4.0f, 14.0f),
+ std::complex<float>(5.0f, 15.0f), std::complex<float>(6.0f, 16.0f)});
+ m.Invoke();
+ EXPECT_THAT(
+ m.ExtractVector<std::complex<float>>(m.output()),
+ ElementsAreArray(
+ {std::complex<float>(1.0f, 11.0f), std::complex<float>(2.0f, 12.0f),
+ std::complex<float>(3.0f, 13.0f), std::complex<float>(4.0f, 14.0f),
+ std::complex<float>(5.0f, 15.0f),
+ std::complex<float>(6.0f, 16.0f)}));
+}
+
} // namespace
} // namespace tflite
int main(int argc, char** argv) {