aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler/xla/literal_comparison.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/compiler/xla/literal_comparison.cc')
-rw-r--r--tensorflow/compiler/xla/literal_comparison.cc33
1 files changed, 14 insertions, 19 deletions
diff --git a/tensorflow/compiler/xla/literal_comparison.cc b/tensorflow/compiler/xla/literal_comparison.cc
index f6ce69eaee..3d8725ed70 100644
--- a/tensorflow/compiler/xla/literal_comparison.cc
+++ b/tensorflow/compiler/xla/literal_comparison.cc
@@ -38,8 +38,8 @@ namespace {
// between the left-hand-side and right-hand-side, by bit-casting to UnsignedT
// -- on miscompare, a nice error message is given in the AssertionFailure.
template <typename FloatT, typename UnsignedT>
-Status CompareFloatsBitwiseEqual(
- FloatT lhs, FloatT rhs, tensorflow::gtl::ArraySlice<int64> multi_index) {
+Status CompareFloatsBitwiseEqual(FloatT lhs, FloatT rhs,
+ absl::Span<const int64> multi_index) {
auto ulhs = tensorflow::bit_cast<UnsignedT>(lhs);
auto urhs = tensorflow::bit_cast<UnsignedT>(rhs);
auto lhs_double = static_cast<double>(lhs);
@@ -60,7 +60,7 @@ Status CompareFloatsBitwiseEqual(
// default gunit implementation).
template <typename NativeT>
Status CompareEqual(NativeT lhs, NativeT rhs,
- tensorflow::gtl::ArraySlice<int64> multi_index) {
+ absl::Span<const int64> multi_index) {
if (lhs == rhs) {
return Status::OK();
}
@@ -74,28 +74,27 @@ Status CompareEqual(NativeT lhs, NativeT rhs,
// comparison is requested.
template <>
Status CompareEqual<bfloat16>(bfloat16 lhs, bfloat16 rhs,
- tensorflow::gtl::ArraySlice<int64> multi_index) {
+ absl::Span<const int64> multi_index) {
return CompareFloatsBitwiseEqual<bfloat16, uint16>(lhs, rhs, multi_index);
}
template <>
-Status CompareEqual<Eigen::half>(
- Eigen::half lhs, Eigen::half rhs,
- tensorflow::gtl::ArraySlice<int64> multi_index) {
+Status CompareEqual<Eigen::half>(Eigen::half lhs, Eigen::half rhs,
+ absl::Span<const int64> multi_index) {
return CompareFloatsBitwiseEqual<Eigen::half, uint16>(lhs, rhs, multi_index);
}
template <>
Status CompareEqual<float>(float lhs, float rhs,
- tensorflow::gtl::ArraySlice<int64> multi_index) {
+ absl::Span<const int64> multi_index) {
return CompareFloatsBitwiseEqual<float, uint32>(lhs, rhs, multi_index);
}
template <>
Status CompareEqual<double>(double lhs, double rhs,
- tensorflow::gtl::ArraySlice<int64> multi_index) {
+ absl::Span<const int64> multi_index) {
return CompareFloatsBitwiseEqual<double, uint64>(lhs, rhs, multi_index);
}
template <>
Status CompareEqual<complex64>(complex64 lhs, complex64 rhs,
- tensorflow::gtl::ArraySlice<int64> multi_index) {
+ absl::Span<const int64> multi_index) {
auto res = CompareEqual<float>(lhs.real(), rhs.real(), multi_index);
if (!res.ok()) {
return res;
@@ -108,8 +107,7 @@ Status CompareEqual<complex64>(complex64 lhs, complex64 rhs,
// elements are equal.
template <typename NativeT>
Status Equal(LiteralSlice expected, LiteralSlice actual,
- tensorflow::gtl::MutableArraySlice<int64> multi_index,
- int64 dimension) {
+ absl::Span<int64> multi_index, int64 dimension) {
if (dimension == expected.shape().dimensions_size()) {
NativeT expected_value = expected.Get<NativeT>(multi_index);
NativeT actual_value = actual.Get<NativeT>(multi_index);
@@ -305,8 +303,7 @@ class NearComparator {
}
// Insert the given error into the given error bucket vector.
- void UpdateErrorBucket(
- float error, tensorflow::gtl::MutableArraySlice<int64> error_buckets) {
+ void UpdateErrorBucket(float error, absl::Span<int64> error_buckets) {
CHECK_EQ(error_buckets.size(), kErrorBucketBounds.size());
for (int i = 0; i < error_buckets.size(); ++i) {
if (error >= kErrorBucketBounds[i]) {
@@ -410,10 +407,8 @@ class NearComparator {
// Fast path optimization for the case were layouts match.
if (LayoutUtil::Equal(actual_.shape().layout(),
expected_.shape().layout())) {
- tensorflow::gtl::ArraySlice<const NativeT> expected_data =
- expected_.data<NativeT>();
- tensorflow::gtl::ArraySlice<const NativeT> actual_data =
- actual_.data<NativeT>();
+ absl::Span<const NativeT> expected_data = expected_.data<NativeT>();
+ absl::Span<const NativeT> actual_data = actual_.data<NativeT>();
const int64 len = expected_data.size();
for (int64 i = 0; i < len; ++i) {
CompareValues(expected_data[i], actual_data[i], i);
@@ -488,7 +483,7 @@ class NearComparator {
}
auto print_accum_buckets = [&](const string& header, int64 total,
- tensorflow::gtl::ArraySlice<int64> buckets) {
+ absl::Span<const int64> buckets) {
StrAppend(&out, header, ":\n");
StrAppendFormat(&out, " < %-6g : %7d (%s)\n", kErrorBucketBounds[0],
total - buckets[0],