aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
-rw-r--r--tensorflow/compiler/xla/literal_util.h1
-rw-r--r--tensorflow/compiler/xla/tests/literal_test_util.cc186
2 files changed, 138 insertions, 49 deletions
diff --git a/tensorflow/compiler/xla/literal_util.h b/tensorflow/compiler/xla/literal_util.h
index e0196509a7..2b68b8f177 100644
--- a/tensorflow/compiler/xla/literal_util.h
+++ b/tensorflow/compiler/xla/literal_util.h
@@ -486,6 +486,7 @@ class Literal {
std::vector<std::unique_ptr<Literal>> elements);
// Returns a string representation of the literal value.
+ // Warning: this function can take minutes for multi-million element Literals.
string ToString(bool print_layout = false) const;
// Invokes the "per cell" callback for each element in the provided
diff --git a/tensorflow/compiler/xla/tests/literal_test_util.cc b/tensorflow/compiler/xla/tests/literal_test_util.cc
index f8205de702..39c07297d6 100644
--- a/tensorflow/compiler/xla/tests/literal_test_util.cc
+++ b/tensorflow/compiler/xla/tests/literal_test_util.cc
@@ -355,9 +355,9 @@ class NearComparator {
// temporary files on failure. Returns true if literals match.
bool ExpectNear(const Literal& expected, const Literal& actual) {
VLOG(1) << "expected:";
- XLA_VLOG_LINES(1, expected.ToString());
+ XLA_VLOG_LINES(1, TruncateHugeLiteral(expected));
VLOG(1) << "actual:";
- XLA_VLOG_LINES(1, actual.ToString());
+ XLA_VLOG_LINES(1, TruncateHugeLiteral(actual));
// If the shapes mismatch, we simply fail the expectation instead of
// printing out data, as it's a type error rather than a value error.
@@ -377,6 +377,7 @@ class NearComparator {
max_rel_err_ = 0.0;
max_abs_err_ = 0.0;
miscompares_ = Literal(ShapeUtil::ChangeElementType(actual.shape(), PRED));
+ miscompares_.PopulateWithValue(false);
multi_index_.resize(expected.shape().dimensions_size(), 0);
switch (expected.shape().element_type()) {
@@ -404,21 +405,33 @@ class NearComparator {
if (num_miscompares_ > 0) {
if (!VLOG_IS_ON(1)) {
LOG(INFO) << "expected: " << ShapeUtil::HumanString(expected.shape())
- << " " << expected.ToString();
+ << " " << TruncateHugeLiteral(expected);
LOG(INFO) << "actual: " << ShapeUtil::HumanString(actual.shape())
- << " " << actual.ToString();
+ << " " << TruncateHugeLiteral(actual);
+ LOG(INFO) << "Dumping literals to temp files...";
+ WriteLiteralToTempFile(expected, "expected");
+ WriteLiteralToTempFile(actual, "actual");
+ WriteLiteralToTempFile(miscompares_, "miscompares");
}
EXPECT_TRUE(num_miscompares_ == 0)
<< "\nmax relative mismatch at index "
- << LiteralTestUtil::MultiIndexAsString(max_rel_multi_index_)
+ << LiteralTestUtil::MultiIndexAsString(
+ IndexUtil::LinearIndexToMultidimensionalIndex(
+ actual.shape(), max_rel_linear_index_))
<< "\nmaximum relative error " << max_rel_err_
<< "\nmax absolute mismatch at index "
- << LiteralTestUtil::MultiIndexAsString(max_abs_multi_index_)
+ << LiteralTestUtil::MultiIndexAsString(
+ IndexUtil::LinearIndexToMultidimensionalIndex(
+ actual.shape(), max_abs_linear_index_))
<< "\nmaximum absolute error " << max_abs_err_
<< "\nfirst mismatch at index "
- << LiteralTestUtil::MultiIndexAsString(first_multi_index_)
+ << LiteralTestUtil::MultiIndexAsString(
+ IndexUtil::LinearIndexToMultidimensionalIndex(
+ actual.shape(), first_linear_index_))
<< "\nlast mismatch at index "
- << LiteralTestUtil::MultiIndexAsString(last_multi_index_)
+ << LiteralTestUtil::MultiIndexAsString(
+ IndexUtil::LinearIndexToMultidimensionalIndex(
+ actual.shape(), last_linear_index_))
<< "\ntotal absolute error " << abs_diff_sum_
<< "\ntotal absolute error of miscompares "
<< abs_diff_miscompare_sum_ << "\ntotal relative error "
@@ -426,10 +439,6 @@ class NearComparator {
<< "\ntotal relative error of miscompares "
<< (abs_diff_miscompare_sum_ / abs_expected_miscompare_sum_)
<< "\nfailure count " << num_miscompares_;
-
- WriteLiteralToTempFile(expected, "expected");
- WriteLiteralToTempFile(actual, "actual");
- WriteLiteralToTempFile(miscompares_, "miscompares");
}
return num_miscompares_ == 0;
}
@@ -457,57 +466,93 @@ class NearComparator {
return true;
}
- float abs_diff = std::abs(actual - expected);
- float rel_err = abs_diff / std::abs(expected);
+ const float abs_diff = std::abs(actual - expected);
+ const float rel_err = abs_diff / std::abs(expected);
+ const bool nan_mismatch = NanMismatch<NativeT>(expected, actual);
+ const bool mismatch =
+ (nan_mismatch || (abs_diff >= error_.abs && rel_err >= error_.rel));
+ return !mismatch;
+ }
+
+ // Assumes that expected vs actual fail ExpectValuesNear.
+ template <typename NativeT>
+ void UpdateAndLogMiscompares(const NativeT expected, const NativeT actual,
+ const Shape& shape, const int64 linear_index) {
+ const float abs_diff = std::abs(actual - expected);
+ const float rel_err = abs_diff / std::abs(expected);
abs_diff_sum_ += abs_diff;
abs_expected_sum_ += std::abs(expected);
if (rel_err > max_rel_err_) {
max_rel_err_ = rel_err;
- max_rel_multi_index_ = multi_index_;
+ max_rel_linear_index_ = linear_index;
}
if (abs_diff > max_abs_err_) {
max_abs_err_ = abs_diff;
- max_abs_multi_index_ = multi_index_;
+ max_abs_linear_index_ = linear_index;
}
- VLOG(10) << tensorflow::strings::Printf(
- "index %s abs_diff %f rel_err %f",
- LiteralTestUtil::MultiIndexAsString(multi_index_).c_str(), abs_diff,
- rel_err);
- bool nan_mismatch = NanMismatch<NativeT>(expected, actual);
- bool mismatch =
- (nan_mismatch || (abs_diff >= error_.abs && rel_err >= error_.rel));
- if (mismatch) {
- abs_diff_miscompare_sum_ += abs_diff;
- abs_expected_miscompare_sum_ += std::abs(expected);
- const int64 kMaxFailures = 2;
- if (num_miscompares_ < kMaxFailures) {
- ::testing::Message msg;
- msg << "mismatch at index "
- << LiteralTestUtil::MultiIndexAsString(multi_index_) << " abs diff "
- << abs_diff << " rel err " << rel_err << " failure #"
- << num_miscompares_;
- ExpectNear<NativeT>(expected, actual, msg);
- } else if (num_miscompares_ == kMaxFailures) {
- LOG(ERROR)
- << "reached max 'loud' failure count; silently proceeding...";
- }
- if (num_miscompares_ == 0) {
- first_multi_index_ = multi_index_;
- }
- num_miscompares_++;
- last_multi_index_ = multi_index_;
+ if (VLOG_IS_ON(10)) {
+ VLOG(10) << tensorflow::strings::Printf(
+ "index %s abs_diff %f rel_err %f",
+ LiteralTestUtil::MultiIndexAsString(
+ IndexUtil::LinearIndexToMultidimensionalIndex(shape,
+ linear_index))
+ .c_str(),
+ abs_diff, rel_err);
}
- return !mismatch;
+ abs_diff_miscompare_sum_ += abs_diff;
+ abs_expected_miscompare_sum_ += std::abs(expected);
+ const int64 kMaxFailures = 2;
+ if (num_miscompares_ < kMaxFailures) {
+ const auto multi_index =
+ IndexUtil::LinearIndexToMultidimensionalIndex(shape, linear_index);
+ ::testing::Message msg;
+ msg << "mismatch at index "
+ << LiteralTestUtil::MultiIndexAsString(multi_index) << " abs diff "
+ << abs_diff << " rel err " << rel_err << " failure #"
+ << num_miscompares_;
+ ExpectNear<NativeT>(expected, actual, msg);
+ } else if (num_miscompares_ == kMaxFailures) {
+ LOG(ERROR) << "reached max 'loud' failure count; silently proceeding...";
+ }
+ if (num_miscompares_ == 0) {
+ first_linear_index_ = linear_index;
+ }
+ num_miscompares_++;
+ last_linear_index_ = linear_index;
+ miscompares_.data<bool>()[linear_index] = true;
}
// Recursive function which compares the two given literals elementwise.
template <typename NativeT>
void ExpectLiteralsNear(const Literal& expected, const Literal& actual,
int64 dimension) {
+ // Fast path optimization for the case were layouts match.
+ if (LayoutUtil::Equal(actual.shape().layout(), expected.shape().layout())) {
+ tensorflow::gtl::ArraySlice<const NativeT> expected_data =
+ expected.data<NativeT>();
+ tensorflow::gtl::ArraySlice<const NativeT> actual_data =
+ actual.data<NativeT>();
+ const int64 len = expected_data.size();
+ for (int64 i = 0; i < len; ++i) {
+ const bool near = ExpectValuesNear(expected_data[i], actual_data[i]);
+ if (!near) {
+ UpdateAndLogMiscompares<NativeT>(expected_data[i], actual_data[i],
+ actual.shape(), i);
+ }
+ }
+ return;
+ }
+
if (dimension == expected.shape().dimensions_size()) {
bool near = ExpectValuesNear(expected.Get<NativeT>(multi_index_),
actual.Get<NativeT>(multi_index_));
- miscompares_.Set<bool>(multi_index_, !near);
+ if (!near) {
+ UpdateAndLogMiscompares<NativeT>(
+ expected.Get<NativeT>(multi_index_),
+ actual.Get<NativeT>(multi_index_), actual.shape(),
+ IndexUtil::MultidimensionalIndexToLinearIndex(actual.shape(),
+ multi_index_));
+ }
} else {
for (int64 i = 0; i < expected.shape().dimensions(dimension); ++i) {
multi_index_[dimension] = i;
@@ -528,6 +573,32 @@ class NearComparator {
LOG(ERROR) << "wrote to " << name << " file: " << filename;
}
+ // Gets the total element count. For tuples, this is not the count of tuple
+ // elements, but the sum of elements of each tuple element.
+ int64 RecursiveElementCount(const Shape& shape) {
+ if (ShapeUtil::IsTuple(shape)) {
+ const int64 tuple_elements = ShapeUtil::TupleElementCount(shape);
+ int64 total = 0;
+ for (int64 i = 0; i < tuple_elements; ++i) {
+ total +=
+ RecursiveElementCount(ShapeUtil::GetTupleElementShape(shape, i));
+ }
+ return total;
+ } else {
+ return ShapeUtil::ElementsIn(shape);
+ }
+ }
+
+ // Calling ToString on a literal with over 100 million elements takes around
+ // 3 minutes. The utility of printing a literal with >1000 elements is
+ // questionable, especially when writing the Literal proto to disk is orders
+ // of magnitude faster.
+ string TruncateHugeLiteral(const Literal& literal) {
+ return RecursiveElementCount(literal.shape()) < 1000
+ ? literal.ToString()
+ : "[TRUNCATED, Literal with more than 1000 values]";
+ }
+
ErrorSpec error_;
// Number of element miscomparisons encountered so far.
@@ -548,10 +619,10 @@ class NearComparator {
double abs_expected_miscompare_sum_;
float max_rel_err_;
float max_abs_err_;
- std::vector<int64> first_multi_index_;
- std::vector<int64> last_multi_index_;
- std::vector<int64> max_rel_multi_index_;
- std::vector<int64> max_abs_multi_index_;
+ int64 first_linear_index_;
+ int64 last_linear_index_;
+ int64 max_rel_linear_index_;
+ int64 max_abs_linear_index_;
};
template <>
@@ -584,6 +655,23 @@ bool NearComparator::ExpectValuesNear<half>(half expected, half actual) {
static_cast<float>(std::move(actual)));
}
+template <>
+void NearComparator::UpdateAndLogMiscompares<bfloat16>(
+ const bfloat16 expected, const bfloat16 actual, const Shape& shape,
+ const int64 linear_index) {
+ UpdateAndLogMiscompares(static_cast<float>(expected),
+ static_cast<float>(actual), shape, linear_index);
+}
+
+template <>
+void NearComparator::UpdateAndLogMiscompares<half>(half expected, half actual,
+ const Shape& shape,
+ const int64 linear_index) {
+ UpdateAndLogMiscompares(static_cast<float>(std::move(expected)),
+ static_cast<float>(std::move(actual)), shape,
+ linear_index);
+}
+
} // namespace
/* static */ ::testing::AssertionResult LiteralTestUtil::Near(