aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/framework/types_test.cc
blob: eb92600397517d1290b12dee3ece7ca08bbf2b14 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
#include "tensorflow/core/framework/types.h"

#include <gtest/gtest.h>
#include "tensorflow/core/framework/type_traits.h"
#include "tensorflow/core/platform/protobuf.h"

namespace tensorflow {
namespace {

TEST(TypesTest, DeviceTypeName) {
  EXPECT_EQ("CPU", DeviceTypeString(DeviceType(DEVICE_CPU)));
  EXPECT_EQ("GPU", DeviceTypeString(DeviceType(DEVICE_GPU)));
}

TEST(TypesTest, kDataTypeRefOffset) {
  // Basic sanity check
  EXPECT_EQ(DT_FLOAT + kDataTypeRefOffset, DT_FLOAT_REF);

  // Use the meta-data provided by proto2 to iterate through the basic
  // types and validate that adding kDataTypeRefOffset gives the
  // corresponding reference type.
  const auto* enum_descriptor = DataType_descriptor();
  int e = DataType_MIN;
  if (e == DT_INVALID) ++e;
  int e_ref = e + kDataTypeRefOffset;
  EXPECT_FALSE(DataType_IsValid(e_ref - 1))
      << "Reference enum "
      << enum_descriptor->FindValueByNumber(e_ref - 1)->name()
      << " without corresponding base enum with value " << e - 1;
  for (;
       DataType_IsValid(e) && DataType_IsValid(e_ref) && e_ref <= DataType_MAX;
       ++e, ++e_ref) {
    string enum_name = enum_descriptor->FindValueByNumber(e)->name();
    string enum_ref_name = enum_descriptor->FindValueByNumber(e_ref)->name();
    EXPECT_EQ(enum_name + "_REF", enum_ref_name)
        << enum_name << "_REF should have value " << e_ref << " not "
        << enum_ref_name;
    // Validate DataTypeString() as well.
    DataType dt_e = static_cast<DataType>(e);
    DataType dt_e_ref = static_cast<DataType>(e_ref);
    EXPECT_EQ(DataTypeString(dt_e) + "_ref", DataTypeString(dt_e_ref));

    // Test DataTypeFromString reverse conversion
    DataType dt_e2, dt_e2_ref;
    EXPECT_TRUE(DataTypeFromString(DataTypeString(dt_e), &dt_e2));
    EXPECT_EQ(dt_e, dt_e2);
    EXPECT_TRUE(DataTypeFromString(DataTypeString(dt_e_ref), &dt_e2_ref));
    EXPECT_EQ(dt_e_ref, dt_e2_ref);
  }
  ASSERT_FALSE(DataType_IsValid(e))
      << "Should define " << enum_descriptor->FindValueByNumber(e)->name()
      << "_REF to be " << e_ref;
  ASSERT_FALSE(DataType_IsValid(e_ref))
      << "Extra reference enum "
      << enum_descriptor->FindValueByNumber(e_ref)->name()
      << " without corresponding base enum with value " << e;
  ASSERT_LT(DataType_MAX, e_ref) << "Gap in reference types, missing value for "
                                 << e_ref;

  // Make sure there are no enums defined after the last regular type before
  // the first reference type.
  for (; e < DataType_MIN + kDataTypeRefOffset; ++e) {
    EXPECT_FALSE(DataType_IsValid(e))
        << "Discontinuous enum value "
        << enum_descriptor->FindValueByNumber(e)->name() << " = " << e;
  }
}

TEST(TypesTest, DataTypeFromString) {
  DataType dt;
  ASSERT_TRUE(DataTypeFromString("int32", &dt));
  EXPECT_EQ(DT_INT32, dt);
  ASSERT_TRUE(DataTypeFromString("int32_ref", &dt));
  EXPECT_EQ(DT_INT32_REF, dt);
  EXPECT_FALSE(DataTypeFromString("int32_ref_ref", &dt));
  EXPECT_FALSE(DataTypeFromString("foo", &dt));
  EXPECT_FALSE(DataTypeFromString("foo_ref", &dt));
  ASSERT_TRUE(DataTypeFromString("int64", &dt));
  EXPECT_EQ(DT_INT64, dt);
  ASSERT_TRUE(DataTypeFromString("int64_ref", &dt));
  EXPECT_EQ(DT_INT64_REF, dt);
  ASSERT_TRUE(DataTypeFromString("quint8_ref", &dt));
  EXPECT_EQ(DT_QUINT8_REF, dt);
  ASSERT_TRUE(DataTypeFromString("bfloat16", &dt));
  EXPECT_EQ(DT_BFLOAT16, dt);
}

template <typename T>
static bool GetQuantized() {
  return is_quantized<T>::value;
}

TEST(TypesTest, QuantizedTypes) {
  // NOTE: GUnit cannot parse is::quantized<TYPE>::value() within the
  // EXPECT_TRUE() clause, so we delegate through a template function.
  EXPECT_TRUE(GetQuantized<qint8>());
  EXPECT_TRUE(GetQuantized<quint8>());
  EXPECT_TRUE(GetQuantized<qint32>());

  EXPECT_FALSE(GetQuantized<int8>());
  EXPECT_FALSE(GetQuantized<uint8>());
  EXPECT_FALSE(GetQuantized<int16>());
  EXPECT_FALSE(GetQuantized<int32>());

  EXPECT_TRUE(DataTypeIsQuantized(DT_QINT8));
  EXPECT_TRUE(DataTypeIsQuantized(DT_QUINT8));
  EXPECT_TRUE(DataTypeIsQuantized(DT_QINT32));

  EXPECT_FALSE(DataTypeIsQuantized(DT_INT8));
  EXPECT_FALSE(DataTypeIsQuantized(DT_UINT8));
  EXPECT_FALSE(DataTypeIsQuantized(DT_INT16));
  EXPECT_FALSE(DataTypeIsQuantized(DT_INT32));
  EXPECT_FALSE(DataTypeIsQuantized(DT_BFLOAT16));
}

}  // namespace
}  // namespace tensorflow