aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/kernels/cast_op_test.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/core/kernels/cast_op_test.cc')
-rw-r--r--tensorflow/core/kernels/cast_op_test.cc45
1 files changed, 35 insertions, 10 deletions
diff --git a/tensorflow/core/kernels/cast_op_test.cc b/tensorflow/core/kernels/cast_op_test.cc
index 7da9d28a3d..cb305de5e3 100644
--- a/tensorflow/core/kernels/cast_op_test.cc
+++ b/tensorflow/core/kernels/cast_op_test.cc
@@ -40,17 +40,27 @@ static Graph* Cast(int num) {
class CastOpTest : public OpsTestBase {
protected:
- void MakeOp(DataType src, DataType dst) {
- TF_EXPECT_OK(NodeDefBuilder("cast_op", "Cast")
- .Input(FakeInput(src))
- .Attr("SrcT", src)
- .Attr("DstT", dst)
- .Finalize(node_def()));
+ void MakeOp(DataType src, DataType dst, bool trunc = false) {
+ if (trunc) {
+ TF_EXPECT_OK(NodeDefBuilder("cast_op", "Cast")
+ .Input(FakeInput(src))
+ .Attr("SrcT", src)
+ .Attr("DstT", dst)
+ .Attr("Truncate", true)
+ .Finalize(node_def()));
+ } else {
+ TF_EXPECT_OK(NodeDefBuilder("cast_op", "Cast")
+ .Input(FakeInput(src))
+ .Attr("SrcT", src)
+ .Attr("DstT", dst)
+ .Finalize(node_def()));
+ }
+
TF_EXPECT_OK(InitOp());
}
template <typename INPUT, typename OUTPUT>
- void CheckCast() {
+ void CheckCast(bool trunc = false) {
DataType in_type = DataTypeToEnum<INPUT>::v();
DataType out_type = DataTypeToEnum<OUTPUT>::v();
MakeOp(in_type, out_type);
@@ -64,22 +74,32 @@ class CastOpTest : public OpsTestBase {
}
};
-#define TEST_CAST(in, out) \
- TEST_F(CastOpTest, TestCast##_##in##_##out) { CheckCast<in, out>(); }
+#define TEST_CAST(in, out) \
+ TEST_F(CastOpTest, TestCast##_##in##_##out) { CheckCast<in, out>(); } \
+ TEST_F(CastOpTest, TestCast2##_##in##_##out) { CheckCast<in, out>(true); }
#define TEST_ALL_CASTS_FROM(in) \
TEST_CAST(in, uint8); \
TEST_CAST(in, uint16); \
+ TEST_CAST(in, uint32); \
+ TEST_CAST(in, uint64); \
TEST_CAST(in, int16); \
TEST_CAST(in, int32); \
TEST_CAST(in, int64); \
TEST_CAST(in, half); \
TEST_CAST(in, float); \
TEST_CAST(in, double); \
- TEST_CAST(in, bfloat16);
+ TEST_CAST(in, bfloat16); \
+ TEST_CAST(in, quint8); \
+ TEST_CAST(in, qint8); \
+ TEST_CAST(in, qint32); \
+ TEST_CAST(in, qint16); \
+ TEST_CAST(in, quint16);
TEST_ALL_CASTS_FROM(uint8)
TEST_ALL_CASTS_FROM(uint16)
+TEST_ALL_CASTS_FROM(uint32)
+TEST_ALL_CASTS_FROM(uint64)
TEST_ALL_CASTS_FROM(int16)
TEST_ALL_CASTS_FROM(int32)
TEST_ALL_CASTS_FROM(int64)
@@ -87,6 +107,11 @@ TEST_ALL_CASTS_FROM(half)
TEST_ALL_CASTS_FROM(float)
TEST_ALL_CASTS_FROM(double)
TEST_ALL_CASTS_FROM(bfloat16)
+TEST_ALL_CASTS_FROM(quint8)
+TEST_ALL_CASTS_FROM(qint8)
+TEST_ALL_CASTS_FROM(qint32)
+TEST_ALL_CASTS_FROM(qint16)
+TEST_ALL_CASTS_FROM(quint16)
#undef TEST_ALL_CASTS_FROM
#undef TEST_CAST