aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler/tf2xla
diff options
context:
space:
mode:
authorGravatar Benjamin Kramer <kramerb@google.com>2018-09-13 09:33:24 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-09-13 09:37:23 -0700
commit88a7c5b98fc1ccb56134003ba3dc88a09385c0a7 (patch)
tree27b9a94786df3b6ed4afa82f34a00c89b447293d /tensorflow/compiler/tf2xla
parenta4bf3d0935570762e9d60eb917d8f42be7e398b4 (diff)
[TF:XLA] Make DataTypeToPrimitiveType work with all quantized types supported by TF
PiperOrigin-RevId: 212826065
Diffstat (limited to 'tensorflow/compiler/tf2xla')
-rw-r--r--tensorflow/compiler/tf2xla/literal_util_test.cc85
-rw-r--r--tensorflow/compiler/tf2xla/type_util.cc11
2 files changed, 54 insertions, 42 deletions
diff --git a/tensorflow/compiler/tf2xla/literal_util_test.cc b/tensorflow/compiler/tf2xla/literal_util_test.cc
index ed452bceeb..15f4c38da2 100644
--- a/tensorflow/compiler/tf2xla/literal_util_test.cc
+++ b/tensorflow/compiler/tf2xla/literal_util_test.cc
@@ -22,48 +22,61 @@ limitations under the License.
#include "tensorflow/core/platform/test.h"
namespace tensorflow {
+namespace {
TEST(LiteralUtil, LiteralToHostTensor) {
// int64 literal can only be converted to an int64 host tensor.
- {
- std::vector<int64> int64_values = {1, 2, 3};
- xla::Literal int64_values_literal =
- xla::LiteralUtil::CreateR1(absl::Span<const int64>(int64_values));
- Tensor host_tensor;
- EXPECT_EQ("Cannot convert literal of type S64 to tensor of type int32",
- LiteralToHostTensor(int64_values_literal, DT_INT32, &host_tensor)
- .error_message());
- EXPECT_EQ("Cannot convert literal of type S64 to tensor of type qint32",
- LiteralToHostTensor(int64_values_literal, DT_QINT32, &host_tensor)
- .error_message());
- EXPECT_TRUE(
- LiteralToHostTensor(int64_values_literal, DT_INT64, &host_tensor).ok());
- test::ExpectTensorEqual<int64>(host_tensor,
- test::AsTensor<int64>(int64_values));
- }
+ std::vector<int64> int64_values = {1, 2, 3};
+ xla::Literal int64_values_literal =
+ xla::LiteralUtil::CreateR1(absl::Span<const int64>(int64_values));
+ Tensor host_tensor;
+ EXPECT_EQ("Cannot convert literal of type S64 to tensor of type int32",
+ LiteralToHostTensor(int64_values_literal, DT_INT32, &host_tensor)
+ .error_message());
+ EXPECT_EQ("Cannot convert literal of type S64 to tensor of type qint32",
+ LiteralToHostTensor(int64_values_literal, DT_QINT32, &host_tensor)
+ .error_message());
+ EXPECT_TRUE(
+ LiteralToHostTensor(int64_values_literal, DT_INT64, &host_tensor).ok());
+ test::ExpectTensorEqual<int64>(host_tensor,
+ test::AsTensor<int64>(int64_values));
+}
+
+template <class T>
+using LiteralUtilTest = ::testing::Test;
+using Types =
+ ::testing::Types<std::pair<int8, qint8>, std::pair<uint8, quint8>,
+ std::pair<int16, qint16>, std::pair<uint16, quint16>,
+ std::pair<int32, qint32>>;
+
+TYPED_TEST_CASE(LiteralUtilTest, Types);
+
+TYPED_TEST(LiteralUtilTest, LiteralToQuantizedHostTensor) {
+ using int_type = typename TypeParam::first_type;
+ using qint_type = typename TypeParam::second_type;
- {
- // Repeat tests with int32.
- Tensor host_tensor;
- std::vector<int32> int32_values = {10, 11};
- xla::Literal int32_values_literal =
- xla::LiteralUtil::CreateR1(absl::Span<const int32>(int32_values));
- EXPECT_TRUE(
- LiteralToHostTensor(int32_values_literal, DT_INT32, &host_tensor).ok());
- test::ExpectTensorEqual<int32>(host_tensor,
- test::AsTensor<int32>(int32_values));
+ Tensor host_tensor;
+ std::vector<int_type> int_values = {10, 11};
+ xla::Literal int_values_literal =
+ xla::LiteralUtil::CreateR1(absl::Span<const int_type>(int_values));
+ EXPECT_TRUE(LiteralToHostTensor(int_values_literal,
+ DataTypeToEnum<int_type>::value, &host_tensor)
+ .ok());
+ test::ExpectTensorEqual<int_type>(host_tensor,
+ test::AsTensor<int_type>(int_values));
- EXPECT_TRUE(
- LiteralToHostTensor(int32_values_literal, DT_QINT32, &host_tensor)
- .ok());
- std::vector<qint32> qint32_values = {10, 11};
- test::ExpectTensorEqual<qint32>(host_tensor,
- test::AsTensor<qint32>(qint32_values));
+ EXPECT_TRUE(LiteralToHostTensor(int_values_literal,
+ DataTypeToEnum<qint_type>::value,
+ &host_tensor)
+ .ok());
+ std::vector<qint_type> qint_values = {10, 11};
+ test::ExpectTensorEqual<qint_type>(host_tensor,
+ test::AsTensor<qint_type>(qint_values));
- EXPECT_EQ("Cannot convert literal of type S32 to tensor of type int64",
- LiteralToHostTensor(int32_values_literal, DT_INT64, &host_tensor)
- .error_message());
- }
+ EXPECT_EQ(
+ error::INVALID_ARGUMENT,
+ LiteralToHostTensor(int_values_literal, DT_INT64, &host_tensor).code());
}
+} // namespace
} // namespace tensorflow
diff --git a/tensorflow/compiler/tf2xla/type_util.cc b/tensorflow/compiler/tf2xla/type_util.cc
index c969212a1b..d00b137662 100644
--- a/tensorflow/compiler/tf2xla/type_util.cc
+++ b/tensorflow/compiler/tf2xla/type_util.cc
@@ -26,21 +26,26 @@ Status DataTypeToPrimitiveType(DataType data_type, xla::PrimitiveType* type) {
*type = xla::PRED;
return Status::OK();
case tensorflow::DT_INT8:
+ case tensorflow::DT_QINT8:
*type = xla::S8;
return Status::OK();
case tensorflow::DT_INT16:
+ case tensorflow::DT_QINT16:
*type = xla::S16;
return Status::OK();
case tensorflow::DT_INT32:
+ case tensorflow::DT_QINT32:
*type = xla::S32;
return Status::OK();
case tensorflow::DT_INT64:
*type = xla::S64;
return Status::OK();
case tensorflow::DT_UINT8:
+ case tensorflow::DT_QUINT8:
*type = xla::U8;
return Status::OK();
case tensorflow::DT_UINT16:
+ case tensorflow::DT_QUINT16:
*type = xla::U16;
return Status::OK();
case tensorflow::DT_UINT32:
@@ -64,12 +69,6 @@ Status DataTypeToPrimitiveType(DataType data_type, xla::PrimitiveType* type) {
case tensorflow::DT_COMPLEX64:
*type = xla::C64;
return Status::OK();
- case tensorflow::DT_QUINT8:
- *type = xla::U8;
- return Status::OK();
- case tensorflow::DT_QINT32:
- *type = xla::S32;
- return Status::OK();
default:
return errors::InvalidArgument(
"Unsupported type in DataTypeToPrimitiveType ",