diff options
Diffstat (limited to 'tensorflow/compiler/xla/literal_comparison.cc')
-rw-r--r-- | tensorflow/compiler/xla/literal_comparison.cc | 33 |
1 files changed, 14 insertions, 19 deletions
diff --git a/tensorflow/compiler/xla/literal_comparison.cc b/tensorflow/compiler/xla/literal_comparison.cc index f6ce69eaee..3d8725ed70 100644 --- a/tensorflow/compiler/xla/literal_comparison.cc +++ b/tensorflow/compiler/xla/literal_comparison.cc @@ -38,8 +38,8 @@ namespace { // between the left-hand-side and right-hand-side, by bit-casting to UnsignedT // -- on miscompare, a nice error message is given in the AssertionFailure. template <typename FloatT, typename UnsignedT> -Status CompareFloatsBitwiseEqual( - FloatT lhs, FloatT rhs, tensorflow::gtl::ArraySlice<int64> multi_index) { +Status CompareFloatsBitwiseEqual(FloatT lhs, FloatT rhs, + absl::Span<const int64> multi_index) { auto ulhs = tensorflow::bit_cast<UnsignedT>(lhs); auto urhs = tensorflow::bit_cast<UnsignedT>(rhs); auto lhs_double = static_cast<double>(lhs); @@ -60,7 +60,7 @@ Status CompareFloatsBitwiseEqual( // default gunit implementation). template <typename NativeT> Status CompareEqual(NativeT lhs, NativeT rhs, - tensorflow::gtl::ArraySlice<int64> multi_index) { + absl::Span<const int64> multi_index) { if (lhs == rhs) { return Status::OK(); } @@ -74,28 +74,27 @@ Status CompareEqual(NativeT lhs, NativeT rhs, // comparison is requested. template <> Status CompareEqual<bfloat16>(bfloat16 lhs, bfloat16 rhs, - tensorflow::gtl::ArraySlice<int64> multi_index) { + absl::Span<const int64> multi_index) { return CompareFloatsBitwiseEqual<bfloat16, uint16>(lhs, rhs, multi_index); } template <> -Status CompareEqual<Eigen::half>( - Eigen::half lhs, Eigen::half rhs, - tensorflow::gtl::ArraySlice<int64> multi_index) { +Status CompareEqual<Eigen::half>(Eigen::half lhs, Eigen::half rhs, + absl::Span<const int64> multi_index) { return CompareFloatsBitwiseEqual<Eigen::half, uint16>(lhs, rhs, multi_index); } template <> Status CompareEqual<float>(float lhs, float rhs, - tensorflow::gtl::ArraySlice<int64> multi_index) { + absl::Span<const int64> multi_index) { return CompareFloatsBitwiseEqual<float, uint32>(lhs, rhs, multi_index); } template <> Status CompareEqual<double>(double lhs, double rhs, - tensorflow::gtl::ArraySlice<int64> multi_index) { + absl::Span<const int64> multi_index) { return CompareFloatsBitwiseEqual<double, uint64>(lhs, rhs, multi_index); } template <> Status CompareEqual<complex64>(complex64 lhs, complex64 rhs, - tensorflow::gtl::ArraySlice<int64> multi_index) { + absl::Span<const int64> multi_index) { auto res = CompareEqual<float>(lhs.real(), rhs.real(), multi_index); if (!res.ok()) { return res; @@ -108,8 +107,7 @@ Status CompareEqual<complex64>(complex64 lhs, complex64 rhs, // elements are equal. template <typename NativeT> Status Equal(LiteralSlice expected, LiteralSlice actual, - tensorflow::gtl::MutableArraySlice<int64> multi_index, - int64 dimension) { + absl::Span<int64> multi_index, int64 dimension) { if (dimension == expected.shape().dimensions_size()) { NativeT expected_value = expected.Get<NativeT>(multi_index); NativeT actual_value = actual.Get<NativeT>(multi_index); @@ -305,8 +303,7 @@ class NearComparator { } // Insert the given error into the given error bucket vector. - void UpdateErrorBucket( - float error, tensorflow::gtl::MutableArraySlice<int64> error_buckets) { + void UpdateErrorBucket(float error, absl::Span<int64> error_buckets) { CHECK_EQ(error_buckets.size(), kErrorBucketBounds.size()); for (int i = 0; i < error_buckets.size(); ++i) { if (error >= kErrorBucketBounds[i]) { @@ -410,10 +407,8 @@ class NearComparator { // Fast path optimization for the case were layouts match. if (LayoutUtil::Equal(actual_.shape().layout(), expected_.shape().layout())) { - tensorflow::gtl::ArraySlice<const NativeT> expected_data = - expected_.data<NativeT>(); - tensorflow::gtl::ArraySlice<const NativeT> actual_data = - actual_.data<NativeT>(); + absl::Span<const NativeT> expected_data = expected_.data<NativeT>(); + absl::Span<const NativeT> actual_data = actual_.data<NativeT>(); const int64 len = expected_data.size(); for (int64 i = 0; i < len; ++i) { CompareValues(expected_data[i], actual_data[i], i); @@ -488,7 +483,7 @@ class NearComparator { } auto print_accum_buckets = [&](const string& header, int64 total, - tensorflow::gtl::ArraySlice<int64> buckets) { + absl::Span<const int64> buckets) { StrAppend(&out, header, ":\n"); StrAppendFormat(&out, " < %-6g : %7d (%s)\n", kErrorBucketBounds[0], total - buckets[0], |