diff options
author | Benjamin Kramer <kramerb@google.com> | 2018-09-13 09:33:24 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-09-13 09:37:23 -0700 |
commit | 88a7c5b98fc1ccb56134003ba3dc88a09385c0a7 (patch) | |
tree | 27b9a94786df3b6ed4afa82f34a00c89b447293d /tensorflow/compiler/tf2xla | |
parent | a4bf3d0935570762e9d60eb917d8f42be7e398b4 (diff) |
[TF:XLA] Make DataTypeToPrimitiveType work with all quantized types supported by TF
PiperOrigin-RevId: 212826065
Diffstat (limited to 'tensorflow/compiler/tf2xla')
-rw-r--r-- | tensorflow/compiler/tf2xla/literal_util_test.cc | 85 | ||||
-rw-r--r-- | tensorflow/compiler/tf2xla/type_util.cc | 11 |
2 files changed, 54 insertions, 42 deletions
diff --git a/tensorflow/compiler/tf2xla/literal_util_test.cc b/tensorflow/compiler/tf2xla/literal_util_test.cc index ed452bceeb..15f4c38da2 100644 --- a/tensorflow/compiler/tf2xla/literal_util_test.cc +++ b/tensorflow/compiler/tf2xla/literal_util_test.cc @@ -22,48 +22,61 @@ limitations under the License. #include "tensorflow/core/platform/test.h" namespace tensorflow { +namespace { TEST(LiteralUtil, LiteralToHostTensor) { // int64 literal can only be converted to an int64 host tensor. - { - std::vector<int64> int64_values = {1, 2, 3}; - xla::Literal int64_values_literal = - xla::LiteralUtil::CreateR1(absl::Span<const int64>(int64_values)); - Tensor host_tensor; - EXPECT_EQ("Cannot convert literal of type S64 to tensor of type int32", - LiteralToHostTensor(int64_values_literal, DT_INT32, &host_tensor) - .error_message()); - EXPECT_EQ("Cannot convert literal of type S64 to tensor of type qint32", - LiteralToHostTensor(int64_values_literal, DT_QINT32, &host_tensor) - .error_message()); - EXPECT_TRUE( - LiteralToHostTensor(int64_values_literal, DT_INT64, &host_tensor).ok()); - test::ExpectTensorEqual<int64>(host_tensor, - test::AsTensor<int64>(int64_values)); - } + std::vector<int64> int64_values = {1, 2, 3}; + xla::Literal int64_values_literal = + xla::LiteralUtil::CreateR1(absl::Span<const int64>(int64_values)); + Tensor host_tensor; + EXPECT_EQ("Cannot convert literal of type S64 to tensor of type int32", + LiteralToHostTensor(int64_values_literal, DT_INT32, &host_tensor) + .error_message()); + EXPECT_EQ("Cannot convert literal of type S64 to tensor of type qint32", + LiteralToHostTensor(int64_values_literal, DT_QINT32, &host_tensor) + .error_message()); + EXPECT_TRUE( + LiteralToHostTensor(int64_values_literal, DT_INT64, &host_tensor).ok()); + test::ExpectTensorEqual<int64>(host_tensor, + test::AsTensor<int64>(int64_values)); +} + +template <class T> +using LiteralUtilTest = ::testing::Test; +using Types = + ::testing::Types<std::pair<int8, qint8>, std::pair<uint8, quint8>, + std::pair<int16, qint16>, std::pair<uint16, quint16>, + std::pair<int32, qint32>>; + +TYPED_TEST_CASE(LiteralUtilTest, Types); + +TYPED_TEST(LiteralUtilTest, LiteralToQuantizedHostTensor) { + using int_type = typename TypeParam::first_type; + using qint_type = typename TypeParam::second_type; - { - // Repeat tests with int32. - Tensor host_tensor; - std::vector<int32> int32_values = {10, 11}; - xla::Literal int32_values_literal = - xla::LiteralUtil::CreateR1(absl::Span<const int32>(int32_values)); - EXPECT_TRUE( - LiteralToHostTensor(int32_values_literal, DT_INT32, &host_tensor).ok()); - test::ExpectTensorEqual<int32>(host_tensor, - test::AsTensor<int32>(int32_values)); + Tensor host_tensor; + std::vector<int_type> int_values = {10, 11}; + xla::Literal int_values_literal = + xla::LiteralUtil::CreateR1(absl::Span<const int_type>(int_values)); + EXPECT_TRUE(LiteralToHostTensor(int_values_literal, + DataTypeToEnum<int_type>::value, &host_tensor) + .ok()); + test::ExpectTensorEqual<int_type>(host_tensor, + test::AsTensor<int_type>(int_values)); - EXPECT_TRUE( - LiteralToHostTensor(int32_values_literal, DT_QINT32, &host_tensor) - .ok()); - std::vector<qint32> qint32_values = {10, 11}; - test::ExpectTensorEqual<qint32>(host_tensor, - test::AsTensor<qint32>(qint32_values)); + EXPECT_TRUE(LiteralToHostTensor(int_values_literal, + DataTypeToEnum<qint_type>::value, + &host_tensor) + .ok()); + std::vector<qint_type> qint_values = {10, 11}; + test::ExpectTensorEqual<qint_type>(host_tensor, + test::AsTensor<qint_type>(qint_values)); - EXPECT_EQ("Cannot convert literal of type S32 to tensor of type int64", - LiteralToHostTensor(int32_values_literal, DT_INT64, &host_tensor) - .error_message()); - } + EXPECT_EQ( + error::INVALID_ARGUMENT, + LiteralToHostTensor(int_values_literal, DT_INT64, &host_tensor).code()); } +} // namespace } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/type_util.cc b/tensorflow/compiler/tf2xla/type_util.cc index c969212a1b..d00b137662 100644 --- a/tensorflow/compiler/tf2xla/type_util.cc +++ b/tensorflow/compiler/tf2xla/type_util.cc @@ -26,21 +26,26 @@ Status DataTypeToPrimitiveType(DataType data_type, xla::PrimitiveType* type) { *type = xla::PRED; return Status::OK(); case tensorflow::DT_INT8: + case tensorflow::DT_QINT8: *type = xla::S8; return Status::OK(); case tensorflow::DT_INT16: + case tensorflow::DT_QINT16: *type = xla::S16; return Status::OK(); case tensorflow::DT_INT32: + case tensorflow::DT_QINT32: *type = xla::S32; return Status::OK(); case tensorflow::DT_INT64: *type = xla::S64; return Status::OK(); case tensorflow::DT_UINT8: + case tensorflow::DT_QUINT8: *type = xla::U8; return Status::OK(); case tensorflow::DT_UINT16: + case tensorflow::DT_QUINT16: *type = xla::U16; return Status::OK(); case tensorflow::DT_UINT32: @@ -64,12 +69,6 @@ Status DataTypeToPrimitiveType(DataType data_type, xla::PrimitiveType* type) { case tensorflow::DT_COMPLEX64: *type = xla::C64; return Status::OK(); - case tensorflow::DT_QUINT8: - *type = xla::U8; - return Status::OK(); - case tensorflow::DT_QINT32: - *type = xla::S32; - return Status::OK(); default: return errors::InvalidArgument( "Unsupported type in DataTypeToPrimitiveType ", |