diff options
Diffstat (limited to 'tensorflow/core/framework/type_traits.h')
-rw-r--r-- | tensorflow/core/framework/type_traits.h | 69 |
1 files changed, 69 insertions, 0 deletions
diff --git a/tensorflow/core/framework/type_traits.h b/tensorflow/core/framework/type_traits.h new file mode 100644 index 0000000000..d87b6ff49b --- /dev/null +++ b/tensorflow/core/framework/type_traits.h @@ -0,0 +1,69 @@ +#ifndef TENSORFLOW_FRAMEWORK_TYPE_TRAITS_H_ +#define TENSORFLOW_FRAMEWORK_TYPE_TRAITS_H_ + +#include <limits> +#include <utility> + +#include "tensorflow/core/framework/types.h" +#include "tensorflow/core/platform/port.h" + +namespace tensorflow { + +// Functions to define quantization attribute of types. +struct true_type { + static const bool value = true; +}; +struct false_type { + static const bool value = false; +}; + +// Default is_quantized is false. +template <typename T> +struct is_quantized : false_type {}; + +// Specialize the quantized types. +template <> +struct is_quantized<qint8> : true_type {}; +template <> +struct is_quantized<quint8> : true_type {}; +template <> +struct is_quantized<qint32> : true_type {}; + +// All types not specialized are marked invalid. +template <class T> +struct IsValidDataType { + static constexpr bool value = false; +}; + +// Extra validity checking; not part of public API. +struct TestIsValidDataType { + static_assert(IsValidDataType<int64>::value, "Incorrect impl for int64"); + static_assert(IsValidDataType<int32>::value, "Incorrect impl for int32"); +}; + +} // namespace tensorflow + +// Define numeric limits for our quantized as subclasses of the +// standard types. +namespace std { +template <> +class numeric_limits<tensorflow::qint8> + : public numeric_limits<tensorflow::int8> {}; +template <> +class numeric_limits<tensorflow::quint8> + : public numeric_limits<tensorflow::uint8> {}; +template <> +class numeric_limits<tensorflow::qint32> + : public numeric_limits<tensorflow::int32> {}; + +// Specialize is_signed for quantized types. +template <> +struct is_signed<tensorflow::qint8> : public is_signed<tensorflow::int8> {}; +template <> +struct is_signed<tensorflow::quint8> : public is_signed<tensorflow::uint8> {}; +template <> +struct is_signed<tensorflow::qint32> : public is_signed<tensorflow::int32> {}; + +} // namespace std + +#endif // TENSORFLOW_FRAMEWORK_TYPE_TRAITS_H_ |