aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar Mark Heffernan <meheff@google.com>2018-08-23 16:06:08 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-08-23 16:10:33 -0700
commit3ed1f31b1a67b40fcc2450f89c7c4652dd837a5f (patch)
tree4337202a3f5d6320202336c70e1356f7f849be7e
parent3dd67a1e399505a297bbdd58440b06855c92a35d (diff)
Clean up Literal comparision mismatch messages.
Make the error messages more consistent between "Near" comparisons and "Equal" comparisons, and give Shape index for mismatches in tuples. Also, polish flag descriptions and test output in run_hlo_module. ==== BEFORE: Mismatches in shape (s32[], s32[], s32[]) (3 elements): Expected equality of these values: 470211269 0 at index {}: expected: 470211269 actual: 0 ==== AFTER: Mismatches in shape (s32[], s32[], s32[]) (3 elements): Array at shape index {1}, first mismatch at array index {}: expected value: 470211269 actual value: 0 Expected literal: (s32[], s32[], s32[]) ( -1865008400, 470211269, 470211269 ) Actual literal: (s32[], s32[], s32[]) ( -1865008400, 0, 470211269 ) PiperOrigin-RevId: 210017733
-rw-r--r--tensorflow/compiler/xla/literal_comparison.cc185
-rw-r--r--tensorflow/compiler/xla/tests/literal_test_util_test.cc6
2 files changed, 105 insertions, 86 deletions
diff --git a/tensorflow/compiler/xla/literal_comparison.cc b/tensorflow/compiler/xla/literal_comparison.cc
index a36e81bc90..f1891335e8 100644
--- a/tensorflow/compiler/xla/literal_comparison.cc
+++ b/tensorflow/compiler/xla/literal_comparison.cc
@@ -47,7 +47,7 @@ Status CompareFloatsBitwiseEqual(
if (ulhs != urhs) {
return InvalidArgument(
"floating values are not bitwise-equal; and equality testing "
- "was requested: %s=%g=%a vs %s=%g=%a at index %s",
+ "was requested: %s=%g=%a vs %s=%g=%a at array index %s",
StrCat(tensorflow::strings::Hex(ulhs)).c_str(), lhs_double, lhs_double,
StrCat(tensorflow::strings::Hex(urhs)).c_str(), rhs_double, rhs_double,
LiteralUtil::MultiIndexAsString(multi_index).c_str());
@@ -65,9 +65,10 @@ Status CompareEqual(NativeT lhs, NativeT rhs,
return Status::OK();
}
return InvalidArgument(
- "Expected equality of these values:\n %s\n %s\nat index %s",
- StrCat(lhs).c_str(), StrCat(rhs).c_str(),
- LiteralUtil::MultiIndexAsString(multi_index).c_str());
+ "first mismatch at array index %s:\n expected value: %s\n actual "
+ "value: %s",
+ LiteralUtil::MultiIndexAsString(multi_index).c_str(), StrCat(lhs).c_str(),
+ StrCat(rhs).c_str());
}
// Specializations for floating types that do bitwise comparisons when equality
@@ -119,7 +120,8 @@ Status Equal(LiteralSlice expected, LiteralSlice actual,
Status result;
for (int64 i = 0; i < expected.shape().dimensions(dimension); ++i) {
multi_index[dimension] = i;
- result.Update(Equal<NativeT>(expected, actual, multi_index, dimension + 1));
+ TF_RETURN_IF_ERROR(
+ Equal<NativeT>(expected, actual, multi_index, dimension + 1));
}
return result;
}
@@ -251,11 +253,6 @@ class NearComparator {
// Runs the comparison between expected and actual literals.
Status Run() {
- VLOG(1) << "expected:";
- XLA_VLOG_LINES(1, ToStringTruncated(expected_));
- VLOG(1) << "actual:";
- XLA_VLOG_LINES(1, ToStringTruncated(actual_));
-
// If the shapes mismatch, we simply fail the expectation instead of
// printing out data, as it's a type error rather than a value error.
TF_RETURN_IF_ERROR(EqualShapes(expected_.shape(), actual_.shape()));
@@ -539,6 +536,62 @@ constexpr std::array<float, 7> NearComparator<NativeT>::kAbsValueBucketBounds;
template <typename NativeT>
constexpr std::array<float, 5> NearComparator<NativeT>::kErrorBucketBounds;
+Status EqualHelper(const LiteralSlice& expected, const LiteralSlice& actual) {
+ TF_RETURN_IF_ERROR(EqualShapes(expected.shape(), actual.shape()));
+ std::vector<int64> multi_index(expected.shape().dimensions_size(), 0);
+ Status result;
+ switch (expected.shape().element_type()) {
+ case PRED:
+ result = Equal<bool>(expected, actual, &multi_index, 0);
+ break;
+ case U8:
+ result = Equal<uint8>(expected, actual, &multi_index, 0);
+ break;
+ case S32:
+ result = Equal<int32>(expected, actual, &multi_index, 0);
+ break;
+ case S64:
+ result = Equal<int64>(expected, actual, &multi_index, 0);
+ break;
+ case U32:
+ result = Equal<uint32>(expected, actual, &multi_index, 0);
+ break;
+ case U64:
+ result = Equal<uint64>(expected, actual, &multi_index, 0);
+ break;
+ case BF16:
+ result = Equal<bfloat16>(expected, actual, &multi_index, 0);
+ break;
+ case F16:
+ result = Equal<half>(expected, actual, &multi_index, 0);
+ break;
+ case F32:
+ result = Equal<float>(expected, actual, &multi_index, 0);
+ break;
+ case F64:
+ result = Equal<double>(expected, actual, &multi_index, 0);
+ break;
+ case C64:
+ result = Equal<complex64>(expected, actual, &multi_index, 0);
+ break;
+ case TUPLE: {
+ for (int i = 0; i < ShapeUtil::TupleElementCount(expected.shape()); ++i) {
+ result.Update(EqualHelper(LiteralSlice(expected, {i}),
+ LiteralSlice(actual, {i})));
+ }
+ break;
+ }
+ case TOKEN:
+ // Tokens have no on-device representation and are trivially equal.
+ return Status::OK();
+ default:
+ LOG(FATAL) << "Unsupported primitive type: "
+ << PrimitiveType_Name(expected.shape().element_type());
+ }
+
+ return result;
+}
+
// Helper function for comparing two literals for nearness. Handles tuple-shapes
// via recursion. shape_index is the ShapeIndex of expected (or actual)
// currently being compared.
@@ -555,17 +608,18 @@ Status NearHelper(const LiteralSlice& expected, const LiteralSlice& actual,
const auto actual_element = LiteralSlice(actual, {i});
ShapeIndex element_index = shape_index;
element_index.push_back(i);
- Status res =
+ Status element_result =
NearHelper(expected_element, actual_element, error, detailed_message,
miscompare_callback, element_index);
- if (!res.ok()) {
- string err_message = Printf("\nArray at shape index %s%s",
- element_index.ToString().c_str(),
- res.error_message().c_str());
+ if (!element_result.ok()) {
+ element_result = InvalidArgument(
+ "Array at shape index %s, %s", element_index.ToString().c_str(),
+ element_result.error_message().c_str());
if (return_status.ok()) {
- return_status = res;
+ return_status = element_result;
} else {
- return_status = AppendStatus(return_status, res.error_message());
+ return_status =
+ AppendStatus(return_status, element_result.error_message());
}
}
}
@@ -611,8 +665,8 @@ Status NearHelper(const LiteralSlice& expected, const LiteralSlice& actual,
}
}
- // Non-floating point literal.
- return literal_comparison::Equal(expected, actual);
+ // Non-floating point, non-tuple literal.
+ return EqualHelper(expected, actual);
}
} // namespace
@@ -668,81 +722,44 @@ Status EqualShapes(const Shape& expected, const Shape& actual) {
return Status::OK();
}
+namespace {
+
+// If result is an error, extend the error message with the expected and actual
+// literals.
+Status EmitLiteralsInErrorMessage(const Status& result,
+ const LiteralSlice& expected,
+ const LiteralSlice& actual) {
+ if (result.ok()) {
+ return result;
+ }
+ return InvalidArgument("%s\n\nExpected literal:\n%s\n\nActual literal:\n%s",
+ result.error_message().c_str(),
+ ToStringTruncated(expected).c_str(),
+ ToStringTruncated(actual).c_str());
+}
+
+} // namespace
+
Status Equal(const LiteralSlice& expected, const LiteralSlice& actual) {
VLOG(1) << "expected:";
XLA_VLOG_LINES(1, expected.ToString());
VLOG(1) << "actual:";
XLA_VLOG_LINES(1, actual.ToString());
-
- TF_RETURN_IF_ERROR(EqualShapes(expected.shape(), actual.shape()));
- std::vector<int64> multi_index(expected.shape().dimensions_size(), 0);
- Status result;
- switch (expected.shape().element_type()) {
- case PRED:
- result = Equal<bool>(expected, actual, &multi_index, 0);
- break;
- case U8:
- result = Equal<uint8>(expected, actual, &multi_index, 0);
- break;
- case S32:
- result = Equal<int32>(expected, actual, &multi_index, 0);
- break;
- case S64:
- result = Equal<int64>(expected, actual, &multi_index, 0);
- break;
- case U32:
- result = Equal<uint32>(expected, actual, &multi_index, 0);
- break;
- case U64:
- result = Equal<uint64>(expected, actual, &multi_index, 0);
- break;
- case BF16:
- result = Equal<bfloat16>(expected, actual, &multi_index, 0);
- break;
- case F16:
- result = Equal<half>(expected, actual, &multi_index, 0);
- break;
- case F32:
- result = Equal<float>(expected, actual, &multi_index, 0);
- break;
- case F64:
- result = Equal<double>(expected, actual, &multi_index, 0);
- break;
- case C64:
- result = Equal<complex64>(expected, actual, &multi_index, 0);
- break;
- case TUPLE: {
- for (int i = 0; i < ShapeUtil::TupleElementCount(expected.shape()); ++i) {
- result.Update(
- Equal(LiteralSlice(expected, {i}), LiteralSlice(actual, {i})));
- }
- break;
- }
- case TOKEN:
- // Tokens have no on-device representation and are trivially equal.
- return Status::OK();
- default:
- LOG(FATAL)
- << "Unsupported primitive type in LiteralTestUtil::ExpectEqual: "
- << PrimitiveType_Name(expected.shape().element_type());
- }
-
- if (result.ok()) {
- return Status::OK();
- }
-
- return AppendStatus(
- result, tensorflow::strings::Printf("\nexpected: %s\nactual: %s",
- ToStringTruncated(expected).c_str(),
- ToStringTruncated(actual).c_str()));
+ Status result = EqualHelper(expected, actual);
+ return EmitLiteralsInErrorMessage(result, expected, actual);
}
Status Near(const LiteralSlice& expected, const LiteralSlice& actual,
const ErrorSpec& error, bool detailed_message,
const MiscompareCallback& miscompare_callback) {
- return NearHelper(expected, actual, error, detailed_message,
- miscompare_callback,
- /*shape_index=*/{});
+ VLOG(1) << "Expected literal:";
+ XLA_VLOG_LINES(1, expected.ToString());
+ VLOG(1) << "Actual literal:";
+ XLA_VLOG_LINES(1, actual.ToString());
+ Status result =
+ NearHelper(expected, actual, error, detailed_message, miscompare_callback,
+ /*shape_index=*/{});
+ return EmitLiteralsInErrorMessage(result, expected, actual);
}
string ToStringTruncated(const LiteralSlice& literal) {
diff --git a/tensorflow/compiler/xla/tests/literal_test_util_test.cc b/tensorflow/compiler/xla/tests/literal_test_util_test.cc
index f297b2b847..d481fdfee3 100644
--- a/tensorflow/compiler/xla/tests/literal_test_util_test.cc
+++ b/tensorflow/compiler/xla/tests/literal_test_util_test.cc
@@ -105,8 +105,10 @@ TEST(LiteralTestUtilTest, NotEqualHasValuesInMessage) {
auto actual = LiteralUtil::CreateR1<int32>({4, 5, 6});
::testing::AssertionResult result =
LiteralTestUtil::Equal(*expected, *actual);
- EXPECT_THAT(result.message(), ::testing::HasSubstr("expected: {1, 2, 3}"));
- EXPECT_THAT(result.message(), ::testing::HasSubstr("actual: {4, 5, 6}"));
+ EXPECT_THAT(result.message(),
+ ::testing::HasSubstr("Expected literal:\n{1, 2, 3}"));
+ EXPECT_THAT(result.message(),
+ ::testing::HasSubstr("Actual literal:\n{4, 5, 6}"));
}
TEST(LiteralTestUtilTest, NearComparatorR1) {