aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar Nick Desaulniers <ndesaulniers@google.com>2018-01-26 15:01:40 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-01-26 15:17:51 -0800
commit3e9bf0874ed19b1f96f835c444a4b80167de4663 (patch)
tree6cd016594ec079cc33b907a4aa5af5790f5356b5
parentb9494ce8990cab65a20f3d5110f4c2c4402342be (diff)
[XLA] optimize NearComparator#ExpectLiteralsNear()
While tracking down the issue of timeouts when running THE ISOLATOR, it was observed that NearComparator#ExpectLiteralsNear() could be optimized in the case of matching layouts to not compute multi indexes. In the process of tracking down timeouts in THE ISOLATOR, I had assumed that time spent was dominated by either generating input data, executing the input data on various backends, or comparing the data. Never assume you know where the time is spent in a program; the profiler may surprise you. After making that optimization and then profiling the code before and after, I was surprised by the profile. Image the shock, horror, and disgust I experienced when discovering that runs of THE ISOLATOR were dominated (45%) by calls to Literal#ToString() in NearComparator#ExpectLiteralsNear() for huge (>120 million elements) literals that failed comparisons. No wonder passing shards of THE ISOLATOR were fast, and failing shards were slow. Further, computing multi indexes many times is expensive enough (18%) to show up in profiles, so avoid calculating it until it is necessary. The optimizations in this patch: * Don't call Literal#ToString() on huge literals that are going to get written to disk anyways. The utility of printing said literal to stdout is suspect. * Initialize NearComparator#miscompares_ to false, only update miscompares_ and other stats when miscompare occurs. * Split NearComparator#ExpectLiteralsNear() into two, since we only need to log and update stats if an actual miscompare occurs. * Add fast path in NearComparator#ExpectLiteralsNear() for case of matching layouts, being careful not to compute multi index unless mismatch actually occurs. This optimized NearComparator#ExpectLiteralsNear() for the case of many element literals, with few miscompares. For many miscompares, we cannot avoid calculating multi indexes, but can fast path for equal layouts. For zero miscompares, we can at least fast path in the case of matching layouts. Before this CL, a run of THE ISOLATOR for a single literal with >120 million elements and a few miscompares took 379s (6.3m). With this CL, the same test case now takes 44s. Beautiful flame graphs omitted from public commit message, regrettably. PiperOrigin-RevId: 183451138
-rw-r--r--tensorflow/compiler/xla/literal_util.h1
-rw-r--r--tensorflow/compiler/xla/tests/literal_test_util.cc186
2 files changed, 138 insertions, 49 deletions
diff --git a/tensorflow/compiler/xla/literal_util.h b/tensorflow/compiler/xla/literal_util.h
index e0196509a7..2b68b8f177 100644
--- a/tensorflow/compiler/xla/literal_util.h
+++ b/tensorflow/compiler/xla/literal_util.h
@@ -486,6 +486,7 @@ class Literal {
std::vector<std::unique_ptr<Literal>> elements);
// Returns a string representation of the literal value.
+ // Warning: this function can take minutes for multi-million element Literals.
string ToString(bool print_layout = false) const;
// Invokes the "per cell" callback for each element in the provided
diff --git a/tensorflow/compiler/xla/tests/literal_test_util.cc b/tensorflow/compiler/xla/tests/literal_test_util.cc
index f8205de702..39c07297d6 100644
--- a/tensorflow/compiler/xla/tests/literal_test_util.cc
+++ b/tensorflow/compiler/xla/tests/literal_test_util.cc
@@ -355,9 +355,9 @@ class NearComparator {
// temporary files on failure. Returns true if literals match.
bool ExpectNear(const Literal& expected, const Literal& actual) {
VLOG(1) << "expected:";
- XLA_VLOG_LINES(1, expected.ToString());
+ XLA_VLOG_LINES(1, TruncateHugeLiteral(expected));
VLOG(1) << "actual:";
- XLA_VLOG_LINES(1, actual.ToString());
+ XLA_VLOG_LINES(1, TruncateHugeLiteral(actual));
// If the shapes mismatch, we simply fail the expectation instead of
// printing out data, as it's a type error rather than a value error.
@@ -377,6 +377,7 @@ class NearComparator {
max_rel_err_ = 0.0;
max_abs_err_ = 0.0;
miscompares_ = Literal(ShapeUtil::ChangeElementType(actual.shape(), PRED));
+ miscompares_.PopulateWithValue(false);
multi_index_.resize(expected.shape().dimensions_size(), 0);
switch (expected.shape().element_type()) {
@@ -404,21 +405,33 @@ class NearComparator {
if (num_miscompares_ > 0) {
if (!VLOG_IS_ON(1)) {
LOG(INFO) << "expected: " << ShapeUtil::HumanString(expected.shape())
- << " " << expected.ToString();
+ << " " << TruncateHugeLiteral(expected);
LOG(INFO) << "actual: " << ShapeUtil::HumanString(actual.shape())
- << " " << actual.ToString();
+ << " " << TruncateHugeLiteral(actual);
+ LOG(INFO) << "Dumping literals to temp files...";
+ WriteLiteralToTempFile(expected, "expected");
+ WriteLiteralToTempFile(actual, "actual");
+ WriteLiteralToTempFile(miscompares_, "miscompares");
}
EXPECT_TRUE(num_miscompares_ == 0)
<< "\nmax relative mismatch at index "
- << LiteralTestUtil::MultiIndexAsString(max_rel_multi_index_)
+ << LiteralTestUtil::MultiIndexAsString(
+ IndexUtil::LinearIndexToMultidimensionalIndex(
+ actual.shape(), max_rel_linear_index_))
<< "\nmaximum relative error " << max_rel_err_
<< "\nmax absolute mismatch at index "
- << LiteralTestUtil::MultiIndexAsString(max_abs_multi_index_)
+ << LiteralTestUtil::MultiIndexAsString(
+ IndexUtil::LinearIndexToMultidimensionalIndex(
+ actual.shape(), max_abs_linear_index_))
<< "\nmaximum absolute error " << max_abs_err_
<< "\nfirst mismatch at index "
- << LiteralTestUtil::MultiIndexAsString(first_multi_index_)
+ << LiteralTestUtil::MultiIndexAsString(
+ IndexUtil::LinearIndexToMultidimensionalIndex(
+ actual.shape(), first_linear_index_))
<< "\nlast mismatch at index "
- << LiteralTestUtil::MultiIndexAsString(last_multi_index_)
+ << LiteralTestUtil::MultiIndexAsString(
+ IndexUtil::LinearIndexToMultidimensionalIndex(
+ actual.shape(), last_linear_index_))
<< "\ntotal absolute error " << abs_diff_sum_
<< "\ntotal absolute error of miscompares "
<< abs_diff_miscompare_sum_ << "\ntotal relative error "
@@ -426,10 +439,6 @@ class NearComparator {
<< "\ntotal relative error of miscompares "
<< (abs_diff_miscompare_sum_ / abs_expected_miscompare_sum_)
<< "\nfailure count " << num_miscompares_;
-
- WriteLiteralToTempFile(expected, "expected");
- WriteLiteralToTempFile(actual, "actual");
- WriteLiteralToTempFile(miscompares_, "miscompares");
}
return num_miscompares_ == 0;
}
@@ -457,57 +466,93 @@ class NearComparator {
return true;
}
- float abs_diff = std::abs(actual - expected);
- float rel_err = abs_diff / std::abs(expected);
+ const float abs_diff = std::abs(actual - expected);
+ const float rel_err = abs_diff / std::abs(expected);
+ const bool nan_mismatch = NanMismatch<NativeT>(expected, actual);
+ const bool mismatch =
+ (nan_mismatch || (abs_diff >= error_.abs && rel_err >= error_.rel));
+ return !mismatch;
+ }
+
+ // Assumes that expected vs actual fail ExpectValuesNear.
+ template <typename NativeT>
+ void UpdateAndLogMiscompares(const NativeT expected, const NativeT actual,
+ const Shape& shape, const int64 linear_index) {
+ const float abs_diff = std::abs(actual - expected);
+ const float rel_err = abs_diff / std::abs(expected);
abs_diff_sum_ += abs_diff;
abs_expected_sum_ += std::abs(expected);
if (rel_err > max_rel_err_) {
max_rel_err_ = rel_err;
- max_rel_multi_index_ = multi_index_;
+ max_rel_linear_index_ = linear_index;
}
if (abs_diff > max_abs_err_) {
max_abs_err_ = abs_diff;
- max_abs_multi_index_ = multi_index_;
+ max_abs_linear_index_ = linear_index;
}
- VLOG(10) << tensorflow::strings::Printf(
- "index %s abs_diff %f rel_err %f",
- LiteralTestUtil::MultiIndexAsString(multi_index_).c_str(), abs_diff,
- rel_err);
- bool nan_mismatch = NanMismatch<NativeT>(expected, actual);
- bool mismatch =
- (nan_mismatch || (abs_diff >= error_.abs && rel_err >= error_.rel));
- if (mismatch) {
- abs_diff_miscompare_sum_ += abs_diff;
- abs_expected_miscompare_sum_ += std::abs(expected);
- const int64 kMaxFailures = 2;
- if (num_miscompares_ < kMaxFailures) {
- ::testing::Message msg;
- msg << "mismatch at index "
- << LiteralTestUtil::MultiIndexAsString(multi_index_) << " abs diff "
- << abs_diff << " rel err " << rel_err << " failure #"
- << num_miscompares_;
- ExpectNear<NativeT>(expected, actual, msg);
- } else if (num_miscompares_ == kMaxFailures) {
- LOG(ERROR)
- << "reached max 'loud' failure count; silently proceeding...";
- }
- if (num_miscompares_ == 0) {
- first_multi_index_ = multi_index_;
- }
- num_miscompares_++;
- last_multi_index_ = multi_index_;
+ if (VLOG_IS_ON(10)) {
+ VLOG(10) << tensorflow::strings::Printf(
+ "index %s abs_diff %f rel_err %f",
+ LiteralTestUtil::MultiIndexAsString(
+ IndexUtil::LinearIndexToMultidimensionalIndex(shape,
+ linear_index))
+ .c_str(),
+ abs_diff, rel_err);
}
- return !mismatch;
+ abs_diff_miscompare_sum_ += abs_diff;
+ abs_expected_miscompare_sum_ += std::abs(expected);
+ const int64 kMaxFailures = 2;
+ if (num_miscompares_ < kMaxFailures) {
+ const auto multi_index =
+ IndexUtil::LinearIndexToMultidimensionalIndex(shape, linear_index);
+ ::testing::Message msg;
+ msg << "mismatch at index "
+ << LiteralTestUtil::MultiIndexAsString(multi_index) << " abs diff "
+ << abs_diff << " rel err " << rel_err << " failure #"
+ << num_miscompares_;
+ ExpectNear<NativeT>(expected, actual, msg);
+ } else if (num_miscompares_ == kMaxFailures) {
+ LOG(ERROR) << "reached max 'loud' failure count; silently proceeding...";
+ }
+ if (num_miscompares_ == 0) {
+ first_linear_index_ = linear_index;
+ }
+ num_miscompares_++;
+ last_linear_index_ = linear_index;
+ miscompares_.data<bool>()[linear_index] = true;
}
// Recursive function which compares the two given literals elementwise.
template <typename NativeT>
void ExpectLiteralsNear(const Literal& expected, const Literal& actual,
int64 dimension) {
+ // Fast path optimization for the case were layouts match.
+ if (LayoutUtil::Equal(actual.shape().layout(), expected.shape().layout())) {
+ tensorflow::gtl::ArraySlice<const NativeT> expected_data =
+ expected.data<NativeT>();
+ tensorflow::gtl::ArraySlice<const NativeT> actual_data =
+ actual.data<NativeT>();
+ const int64 len = expected_data.size();
+ for (int64 i = 0; i < len; ++i) {
+ const bool near = ExpectValuesNear(expected_data[i], actual_data[i]);
+ if (!near) {
+ UpdateAndLogMiscompares<NativeT>(expected_data[i], actual_data[i],
+ actual.shape(), i);
+ }
+ }
+ return;
+ }
+
if (dimension == expected.shape().dimensions_size()) {
bool near = ExpectValuesNear(expected.Get<NativeT>(multi_index_),
actual.Get<NativeT>(multi_index_));
- miscompares_.Set<bool>(multi_index_, !near);
+ if (!near) {
+ UpdateAndLogMiscompares<NativeT>(
+ expected.Get<NativeT>(multi_index_),
+ actual.Get<NativeT>(multi_index_), actual.shape(),
+ IndexUtil::MultidimensionalIndexToLinearIndex(actual.shape(),
+ multi_index_));
+ }
} else {
for (int64 i = 0; i < expected.shape().dimensions(dimension); ++i) {
multi_index_[dimension] = i;
@@ -528,6 +573,32 @@ class NearComparator {
LOG(ERROR) << "wrote to " << name << " file: " << filename;
}
+ // Gets the total element count. For tuples, this is not the count of tuple
+ // elements, but the sum of elements of each tuple element.
+ int64 RecursiveElementCount(const Shape& shape) {
+ if (ShapeUtil::IsTuple(shape)) {
+ const int64 tuple_elements = ShapeUtil::TupleElementCount(shape);
+ int64 total = 0;
+ for (int64 i = 0; i < tuple_elements; ++i) {
+ total +=
+ RecursiveElementCount(ShapeUtil::GetTupleElementShape(shape, i));
+ }
+ return total;
+ } else {
+ return ShapeUtil::ElementsIn(shape);
+ }
+ }
+
+ // Calling ToString on a literal with over 100 million elements takes around
+ // 3 minutes. The utility of printing a literal with >1000 elements is
+ // questionable, especially when writing the Literal proto to disk is orders
+ // of magnitude faster.
+ string TruncateHugeLiteral(const Literal& literal) {
+ return RecursiveElementCount(literal.shape()) < 1000
+ ? literal.ToString()
+ : "[TRUNCATED, Literal with more than 1000 values]";
+ }
+
ErrorSpec error_;
// Number of element miscomparisons encountered so far.
@@ -548,10 +619,10 @@ class NearComparator {
double abs_expected_miscompare_sum_;
float max_rel_err_;
float max_abs_err_;
- std::vector<int64> first_multi_index_;
- std::vector<int64> last_multi_index_;
- std::vector<int64> max_rel_multi_index_;
- std::vector<int64> max_abs_multi_index_;
+ int64 first_linear_index_;
+ int64 last_linear_index_;
+ int64 max_rel_linear_index_;
+ int64 max_abs_linear_index_;
};
template <>
@@ -584,6 +655,23 @@ bool NearComparator::ExpectValuesNear<half>(half expected, half actual) {
static_cast<float>(std::move(actual)));
}
+template <>
+void NearComparator::UpdateAndLogMiscompares<bfloat16>(
+ const bfloat16 expected, const bfloat16 actual, const Shape& shape,
+ const int64 linear_index) {
+ UpdateAndLogMiscompares(static_cast<float>(expected),
+ static_cast<float>(actual), shape, linear_index);
+}
+
+template <>
+void NearComparator::UpdateAndLogMiscompares<half>(half expected, half actual,
+ const Shape& shape,
+ const int64 linear_index) {
+ UpdateAndLogMiscompares(static_cast<float>(std::move(expected)),
+ static_cast<float>(std::move(actual)), shape,
+ linear_index);
+}
+
} // namespace
/* static */ ::testing::AssertionResult LiteralTestUtil::Near(