diff options
author | RJ Ryan <rjryan@google.com> | 2018-06-27 18:11:03 -0700 |
---|---|---|
committer | Gunhan Gulsoy <gunan@google.com> | 2018-06-28 21:37:43 -0700 |
commit | 0bf43348c6269cf46b3e16f93831fa05f226b896 (patch) | |
tree | 11b030ed738cb9b066114087361ef202eb6a7b90 /tensorflow/contrib/lite/kernels/cast_test.cc | |
parent | 54edcf739f928d4314e410c050abfc79f27bad38 (diff) |
Add complex64 support to tf.lite runtime.
PiperOrigin-RevId: 202403235
Diffstat (limited to 'tensorflow/contrib/lite/kernels/cast_test.cc')
-rw-r--r-- | tensorflow/contrib/lite/kernels/cast_test.cc | 67 |
1 files changed, 67 insertions, 0 deletions
diff --git a/tensorflow/contrib/lite/kernels/cast_test.cc b/tensorflow/contrib/lite/kernels/cast_test.cc index 53e2000737..954f998206 100644 --- a/tensorflow/contrib/lite/kernels/cast_test.cc +++ b/tensorflow/contrib/lite/kernels/cast_test.cc @@ -12,6 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +#include <complex> + #include <gtest/gtest.h> #include "tensorflow/contrib/lite/interpreter.h" #include "tensorflow/contrib/lite/kernels/register.h" @@ -73,6 +75,71 @@ TEST(CastOpModel, CastBoolToFloat) { ElementsAreArray({1.f, 1.0f, 0.f, 1.0f, 0.0f, 1.0f})); } +TEST(CastOpModel, CastComplex64ToFloat) { + CastOpModel m({TensorType_COMPLEX64, {2, 3}}, {TensorType_FLOAT32, {2, 3}}); + m.PopulateTensor<std::complex<float>>( + m.input(), + {std::complex<float>(1.0f, 11.0f), std::complex<float>(2.0f, 12.0f), + std::complex<float>(3.0f, 13.0f), std::complex<float>(4.0f, 14.0f), + std::complex<float>(5.0f, 15.0f), std::complex<float>(6.0f, 16.0f)}); + m.Invoke(); + EXPECT_THAT(m.ExtractVector<float>(m.output()), + ElementsAreArray({1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f})); +} + +TEST(CastOpModel, CastFloatToComplex64) { + CastOpModel m({TensorType_FLOAT32, {2, 3}}, {TensorType_COMPLEX64, {2, 3}}); + m.PopulateTensor<float>(m.input(), {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f}); + m.Invoke(); + EXPECT_THAT( + m.ExtractVector<std::complex<float>>(m.output()), + ElementsAreArray( + {std::complex<float>(1.0f, 0.0f), std::complex<float>(2.0f, 0.0f), + std::complex<float>(3.0f, 0.0f), std::complex<float>(4.0f, 0.0f), + std::complex<float>(5.0f, 0.0f), std::complex<float>(6.0f, 0.0f)})); +} + +TEST(CastOpModel, CastComplex64ToInt) { + CastOpModel m({TensorType_COMPLEX64, {2, 3}}, {TensorType_INT32, {2, 3}}); + m.PopulateTensor<std::complex<float>>( + m.input(), + {std::complex<float>(1.0f, 11.0f), std::complex<float>(2.0f, 12.0f), + std::complex<float>(3.0f, 13.0f), std::complex<float>(4.0f, 14.0f), + std::complex<float>(5.0f, 15.0f), std::complex<float>(6.0f, 16.0f)}); + m.Invoke(); + EXPECT_THAT(m.ExtractVector<int>(m.output()), + ElementsAreArray({1, 2, 3, 4, 5, 6})); +} + +TEST(CastOpModel, CastIntToComplex64) { + CastOpModel m({TensorType_INT32, {2, 3}}, {TensorType_COMPLEX64, {2, 3}}); + m.PopulateTensor<int>(m.input(), {1, 2, 3, 4, 5, 6}); + m.Invoke(); + EXPECT_THAT( + m.ExtractVector<std::complex<float>>(m.output()), + ElementsAreArray( + {std::complex<float>(1.0f, 0.0f), std::complex<float>(2.0f, 0.0f), + std::complex<float>(3.0f, 0.0f), std::complex<float>(4.0f, 0.0f), + std::complex<float>(5.0f, 0.0f), std::complex<float>(6.0f, 0.0f)})); +} + +TEST(CastOpModel, CastComplex64ToComplex64) { + CastOpModel m({TensorType_COMPLEX64, {2, 3}}, {TensorType_COMPLEX64, {2, 3}}); + m.PopulateTensor<std::complex<float>>( + m.input(), + {std::complex<float>(1.0f, 11.0f), std::complex<float>(2.0f, 12.0f), + std::complex<float>(3.0f, 13.0f), std::complex<float>(4.0f, 14.0f), + std::complex<float>(5.0f, 15.0f), std::complex<float>(6.0f, 16.0f)}); + m.Invoke(); + EXPECT_THAT( + m.ExtractVector<std::complex<float>>(m.output()), + ElementsAreArray( + {std::complex<float>(1.0f, 11.0f), std::complex<float>(2.0f, 12.0f), + std::complex<float>(3.0f, 13.0f), std::complex<float>(4.0f, 14.0f), + std::complex<float>(5.0f, 15.0f), + std::complex<float>(6.0f, 16.0f)})); +} + } // namespace } // namespace tflite int main(int argc, char** argv) { |