diff options
Diffstat (limited to 'tensorflow/core/kernels/cast_op_test.cc')
-rw-r--r-- | tensorflow/core/kernels/cast_op_test.cc | 45 |
1 files changed, 35 insertions, 10 deletions
diff --git a/tensorflow/core/kernels/cast_op_test.cc b/tensorflow/core/kernels/cast_op_test.cc index 7da9d28a3d..cb305de5e3 100644 --- a/tensorflow/core/kernels/cast_op_test.cc +++ b/tensorflow/core/kernels/cast_op_test.cc @@ -40,17 +40,27 @@ static Graph* Cast(int num) { class CastOpTest : public OpsTestBase { protected: - void MakeOp(DataType src, DataType dst) { - TF_EXPECT_OK(NodeDefBuilder("cast_op", "Cast") - .Input(FakeInput(src)) - .Attr("SrcT", src) - .Attr("DstT", dst) - .Finalize(node_def())); + void MakeOp(DataType src, DataType dst, bool trunc = false) { + if (trunc) { + TF_EXPECT_OK(NodeDefBuilder("cast_op", "Cast") + .Input(FakeInput(src)) + .Attr("SrcT", src) + .Attr("DstT", dst) + .Attr("Truncate", true) + .Finalize(node_def())); + } else { + TF_EXPECT_OK(NodeDefBuilder("cast_op", "Cast") + .Input(FakeInput(src)) + .Attr("SrcT", src) + .Attr("DstT", dst) + .Finalize(node_def())); + } + TF_EXPECT_OK(InitOp()); } template <typename INPUT, typename OUTPUT> - void CheckCast() { + void CheckCast(bool trunc = false) { DataType in_type = DataTypeToEnum<INPUT>::v(); DataType out_type = DataTypeToEnum<OUTPUT>::v(); MakeOp(in_type, out_type); @@ -64,22 +74,32 @@ class CastOpTest : public OpsTestBase { } }; -#define TEST_CAST(in, out) \ - TEST_F(CastOpTest, TestCast##_##in##_##out) { CheckCast<in, out>(); } +#define TEST_CAST(in, out) \ + TEST_F(CastOpTest, TestCast##_##in##_##out) { CheckCast<in, out>(); } \ + TEST_F(CastOpTest, TestCast2##_##in##_##out) { CheckCast<in, out>(true); } #define TEST_ALL_CASTS_FROM(in) \ TEST_CAST(in, uint8); \ TEST_CAST(in, uint16); \ + TEST_CAST(in, uint32); \ + TEST_CAST(in, uint64); \ TEST_CAST(in, int16); \ TEST_CAST(in, int32); \ TEST_CAST(in, int64); \ TEST_CAST(in, half); \ TEST_CAST(in, float); \ TEST_CAST(in, double); \ - TEST_CAST(in, bfloat16); + TEST_CAST(in, bfloat16); \ + TEST_CAST(in, quint8); \ + TEST_CAST(in, qint8); \ + TEST_CAST(in, qint32); \ + TEST_CAST(in, qint16); \ + TEST_CAST(in, quint16); TEST_ALL_CASTS_FROM(uint8) TEST_ALL_CASTS_FROM(uint16) +TEST_ALL_CASTS_FROM(uint32) +TEST_ALL_CASTS_FROM(uint64) TEST_ALL_CASTS_FROM(int16) TEST_ALL_CASTS_FROM(int32) TEST_ALL_CASTS_FROM(int64) @@ -87,6 +107,11 @@ TEST_ALL_CASTS_FROM(half) TEST_ALL_CASTS_FROM(float) TEST_ALL_CASTS_FROM(double) TEST_ALL_CASTS_FROM(bfloat16) +TEST_ALL_CASTS_FROM(quint8) +TEST_ALL_CASTS_FROM(qint8) +TEST_ALL_CASTS_FROM(qint32) +TEST_ALL_CASTS_FROM(qint16) +TEST_ALL_CASTS_FROM(quint16) #undef TEST_ALL_CASTS_FROM #undef TEST_CAST |