aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2018-03-13 13:24:30 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-03-13 13:28:54 -0700
commitfc5885026be1fb2feb6a9ac27c6b8dc594e48ea7 (patch)
tree30a4ad1835d0c424a1d16d3f1dd91610cb4ca93c /tensorflow/core
parent571ed3ba4f8734ed891e81ad2b6bb9aadb816218 (diff)
Break the dependency between platform/types.h and bfloat16.h, and between
hash.h and bfloat16.h. This change introduces a generic mechanism for adapting types that are meant to be used in tensorflow's error objects. PiperOrigin-RevId: 188920678
Diffstat (limited to 'tensorflow/core')
-rw-r--r--tensorflow/core/framework/numeric_types.h1
-rw-r--r--tensorflow/core/lib/bfloat16/bfloat16.h3
-rw-r--r--tensorflow/core/lib/core/errors.h35
-rw-r--r--tensorflow/core/lib/hash/hash.h7
-rw-r--r--tensorflow/core/lib/random/random_distributions.h1
-rw-r--r--tensorflow/core/lib/strings/strcat.h3
-rw-r--r--tensorflow/core/platform/types.h2
7 files changed, 38 insertions, 14 deletions
diff --git a/tensorflow/core/framework/numeric_types.h b/tensorflow/core/framework/numeric_types.h
index 4c38fbbe59..dab53cba3e 100644
--- a/tensorflow/core/framework/numeric_types.h
+++ b/tensorflow/core/framework/numeric_types.h
@@ -24,6 +24,7 @@ limitations under the License.
#include "third_party/eigen3/unsupported/Eigen/CXX11/FixedPoint"
// clang-format on
+#include "tensorflow/core/lib/bfloat16/bfloat16.h"
#include "tensorflow/core/platform/types.h"
namespace tensorflow {
diff --git a/tensorflow/core/lib/bfloat16/bfloat16.h b/tensorflow/core/lib/bfloat16/bfloat16.h
index de8f92d1eb..6a1cc0994f 100644
--- a/tensorflow/core/lib/bfloat16/bfloat16.h
+++ b/tensorflow/core/lib/bfloat16/bfloat16.h
@@ -19,6 +19,9 @@ limitations under the License.
#include <cmath>
#include <complex>
+// We need types.h here in order to pick up __BYTE_ORDER__ from cpu_info.h
+#include "tensorflow/core/platform/types.h"
+
#ifdef __CUDACC__
// All functions callable from CUDA code must be qualified with __device__
#define B16_DEVICE_FUNC __host__ __device__
diff --git a/tensorflow/core/lib/core/errors.h b/tensorflow/core/lib/core/errors.h
index 1fd62755d8..1a0f4be2ea 100644
--- a/tensorflow/core/lib/core/errors.h
+++ b/tensorflow/core/lib/core/errors.h
@@ -16,6 +16,8 @@ limitations under the License.
#ifndef TENSORFLOW_LIB_CORE_ERRORS_H_
#define TENSORFLOW_LIB_CORE_ERRORS_H_
+#include <sstream>
+
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/logging.h"
@@ -26,6 +28,33 @@ namespace errors {
typedef ::tensorflow::error::Code Code;
+namespace internal {
+
+// The DECLARE_ERROR macro below only supports types that can be converted
+// into StrCat's AlphaNum. For the other types we rely on a slower path
+// through std::stringstream. To add support of a new type, it is enough to
+// make sure there is an operator<<() for it:
+//
+// std::ostream& operator<<(std::ostream& os, const MyType& foo) {
+// os << foo.ToString();
+// return os;
+// }
+// Eventually absl::strings will have native support for this and we will be
+// able to completely remove PrepareForStrCat().
+template <typename T>
+typename std::enable_if<!std::is_convertible<T, strings::AlphaNum>::value,
+ string>::type
+PrepareForStrCat(const T& t) {
+ std::stringstream ss;
+ ss << t;
+ return ss.str();
+}
+inline const strings::AlphaNum& PrepareForStrCat(const strings::AlphaNum& a) {
+ return a;
+}
+
+} // namespace internal
+
// Append some context to an error message. Each time we append
// context put it on a new line, since it is possible for there
// to be several layers of additional context.
@@ -61,8 +90,10 @@ void AppendToMessage(::tensorflow::Status* status, Args... args) {
#define DECLARE_ERROR(FUNC, CONST) \
template <typename... Args> \
::tensorflow::Status FUNC(Args... args) { \
- return ::tensorflow::Status(::tensorflow::error::CONST, \
- ::tensorflow::strings::StrCat(args...)); \
+ return ::tensorflow::Status( \
+ ::tensorflow::error::CONST, \
+ ::tensorflow::strings::StrCat( \
+ ::tensorflow::errors::internal::PrepareForStrCat(args)...)); \
} \
inline bool Is##FUNC(const ::tensorflow::Status& status) { \
return status.code() == ::tensorflow::error::CONST; \
diff --git a/tensorflow/core/lib/hash/hash.h b/tensorflow/core/lib/hash/hash.h
index b90c6514f2..77b8031598 100644
--- a/tensorflow/core/lib/hash/hash.h
+++ b/tensorflow/core/lib/hash/hash.h
@@ -64,13 +64,6 @@ struct hash<T*> {
};
template <>
-struct hash<bfloat16> {
- size_t operator()(const bfloat16& t) const {
- return std::hash<float>()(static_cast<float>(t));
- }
-};
-
-template <>
struct hash<string> {
size_t operator()(const string& s) const {
return static_cast<size_t>(Hash64(s));
diff --git a/tensorflow/core/lib/random/random_distributions.h b/tensorflow/core/lib/random/random_distributions.h
index 2ebe608fc9..ad16dbf01f 100644
--- a/tensorflow/core/lib/random/random_distributions.h
+++ b/tensorflow/core/lib/random/random_distributions.h
@@ -25,6 +25,7 @@ limitations under the License.
#include <algorithm>
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
+#include "tensorflow/core/lib/bfloat16/bfloat16.h"
#include "tensorflow/core/lib/random/philox_random.h"
namespace tensorflow {
diff --git a/tensorflow/core/lib/strings/strcat.h b/tensorflow/core/lib/strings/strcat.h
index 2bc14945cd..b681f7398d 100644
--- a/tensorflow/core/lib/strings/strcat.h
+++ b/tensorflow/core/lib/strings/strcat.h
@@ -119,9 +119,6 @@ class AlphaNum {
AlphaNum(float f) // NOLINT(runtime/explicit)
: piece_(digits_, strlen(FloatToBuffer(f, digits_))) {}
- AlphaNum(bfloat16 f) // NOLINT(runtime/explicit)
- : piece_(digits_, strlen(FloatToBuffer(static_cast<float>(f), digits_))) {
- }
AlphaNum(double f) // NOLINT(runtime/explicit)
: piece_(digits_, strlen(DoubleToBuffer(f, digits_))) {}
diff --git a/tensorflow/core/platform/types.h b/tensorflow/core/platform/types.h
index e2dd5b003f..38d75dbb32 100644
--- a/tensorflow/core/platform/types.h
+++ b/tensorflow/core/platform/types.h
@@ -35,8 +35,6 @@ limitations under the License.
#include "tensorflow/core/platform/windows/cpu_info.h"
#endif
-#include "tensorflow/core/lib/bfloat16/bfloat16.h"
-
namespace tensorflow {
// Define tensorflow::string to refer to appropriate platform specific type.