aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar Benjamin Kramer <kramerb@google.com>2018-08-30 13:59:11 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-08-30 14:03:08 -0700
commitfeffc075befac53dddc721572493796c8fbffe3c (patch)
tree437de0f7e0dabfbea9782ac0c074d57ed2c1c75f
parentf207dd8964be31ee33e733367f1c9b7325479482 (diff)
[XLA] xla::ContainersEqual -> absl::c_equal
The replacement for the initializer_list overload is a bit sad because MakeSpan doesn't understand initializer_list (and we don't have CTAD yet) PiperOrigin-RevId: 210974939
-rw-r--r--tensorflow/compiler/xla/service/cpu/shape_partition_test.cc22
-rw-r--r--tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc2
-rw-r--r--tensorflow/compiler/xla/service/hlo_instruction.h4
-rw-r--r--tensorflow/compiler/xla/service/hlo_instructions.cc18
-rw-r--r--tensorflow/compiler/xla/service/hlo_verifier.cc6
-rw-r--r--tensorflow/compiler/xla/shape_util.cc30
-rw-r--r--tensorflow/compiler/xla/util.h27
-rw-r--r--tensorflow/compiler/xla/util_test.cc43
8 files changed, 42 insertions, 110 deletions
diff --git a/tensorflow/compiler/xla/service/cpu/shape_partition_test.cc b/tensorflow/compiler/xla/service/cpu/shape_partition_test.cc
index ae80a6f497..7d8e51f909 100644
--- a/tensorflow/compiler/xla/service/cpu/shape_partition_test.cc
+++ b/tensorflow/compiler/xla/service/cpu/shape_partition_test.cc
@@ -102,22 +102,22 @@ TEST_F(ShapePartitionIteratorTest, Shape53WithLayout10) {
{
ShapePartitionIterator iterator(shape, {1});
EXPECT_EQ(1, iterator.GetTotalPartitionCount());
- EXPECT_TRUE(ContainersEqual(Partition({{0, 5}}), iterator.GetPartition(0)));
+ EXPECT_TRUE(absl::c_equal(Partition({{0, 5}}), iterator.GetPartition(0)));
}
{
ShapePartitionIterator iterator(shape, {2});
EXPECT_EQ(2, iterator.GetTotalPartitionCount());
- EXPECT_TRUE(ContainersEqual(Partition({{0, 2}}), iterator.GetPartition(0)));
- EXPECT_TRUE(ContainersEqual(Partition({{2, 3}}), iterator.GetPartition(1)));
+ EXPECT_TRUE(absl::c_equal(Partition({{0, 2}}), iterator.GetPartition(0)));
+ EXPECT_TRUE(absl::c_equal(Partition({{2, 3}}), iterator.GetPartition(1)));
}
{
ShapePartitionIterator iterator(shape, {3});
EXPECT_EQ(3, iterator.GetTotalPartitionCount());
- EXPECT_TRUE(ContainersEqual(Partition({{0, 1}}), iterator.GetPartition(0)));
- EXPECT_TRUE(ContainersEqual(Partition({{1, 1}}), iterator.GetPartition(1)));
- EXPECT_TRUE(ContainersEqual(Partition({{2, 3}}), iterator.GetPartition(2)));
+ EXPECT_TRUE(absl::c_equal(Partition({{0, 1}}), iterator.GetPartition(0)));
+ EXPECT_TRUE(absl::c_equal(Partition({{1, 1}}), iterator.GetPartition(1)));
+ EXPECT_TRUE(absl::c_equal(Partition({{2, 3}}), iterator.GetPartition(2)));
}
}
@@ -128,20 +128,20 @@ TEST_F(ShapePartitionIteratorTest, Shape532WithLayout210) {
ShapePartitionIterator iterator(shape, {1, 1});
EXPECT_EQ(1, iterator.GetTotalPartitionCount());
EXPECT_TRUE(
- ContainersEqual(Partition({{0, 5}, {0, 3}}), iterator.GetPartition(0)));
+ absl::c_equal(Partition({{0, 5}, {0, 3}}), iterator.GetPartition(0)));
}
{
ShapePartitionIterator iterator(shape, {2, 2});
EXPECT_EQ(4, iterator.GetTotalPartitionCount());
EXPECT_TRUE(
- ContainersEqual(Partition({{0, 2}, {0, 1}}), iterator.GetPartition(0)));
+ absl::c_equal(Partition({{0, 2}, {0, 1}}), iterator.GetPartition(0)));
EXPECT_TRUE(
- ContainersEqual(Partition({{0, 2}, {1, 2}}), iterator.GetPartition(1)));
+ absl::c_equal(Partition({{0, 2}, {1, 2}}), iterator.GetPartition(1)));
EXPECT_TRUE(
- ContainersEqual(Partition({{2, 3}, {0, 1}}), iterator.GetPartition(2)));
+ absl::c_equal(Partition({{2, 3}, {0, 1}}), iterator.GetPartition(2)));
EXPECT_TRUE(
- ContainersEqual(Partition({{2, 3}, {1, 2}}), iterator.GetPartition(3)));
+ absl::c_equal(Partition({{2, 3}, {1, 2}}), iterator.GetPartition(3)));
}
}
diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc
index c0c8ae181a..860dd0b50f 100644
--- a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc
+++ b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc
@@ -3295,7 +3295,7 @@ bool IrEmitterUnnested::CheckAndEmitHloWithTile021(HloInstruction* hlo) {
if (!reduced_dims_021.has_value()) {
reduced_dims_021 = curr_reduced_dims_021;
}
- if (!ContainersEqual(*reduced_dims_021, curr_reduced_dims_021)) {
+ if (!absl::c_equal(*reduced_dims_021, curr_reduced_dims_021)) {
// There is more than one possible transpose. Instead of picking one
// transpose, we simply give up here.
return false;
diff --git a/tensorflow/compiler/xla/service/hlo_instruction.h b/tensorflow/compiler/xla/service/hlo_instruction.h
index 4a424cebc0..f3fd287d88 100644
--- a/tensorflow/compiler/xla/service/hlo_instruction.h
+++ b/tensorflow/compiler/xla/service/hlo_instruction.h
@@ -867,8 +867,8 @@ class HloInstruction {
return false;
}
- if (!ContainersEqual(precision_config_.operand_precision(),
- other.precision_config_.operand_precision())) {
+ if (!absl::c_equal(precision_config_.operand_precision(),
+ other.precision_config_.operand_precision())) {
return false;
}
diff --git a/tensorflow/compiler/xla/service/hlo_instructions.cc b/tensorflow/compiler/xla/service/hlo_instructions.cc
index 0b7f741d73..e1c884d856 100644
--- a/tensorflow/compiler/xla/service/hlo_instructions.cc
+++ b/tensorflow/compiler/xla/service/hlo_instructions.cc
@@ -337,11 +337,10 @@ bool HloCollectiveInstruction::IdenticalSlowPath(
/*eq_computations*/) const {
const auto& casted_other =
static_cast<const HloCollectiveInstruction&>(other);
- return ContainersEqual(replica_groups(), casted_other.replica_groups(),
- [](const ReplicaGroup& a, const ReplicaGroup& b) {
- return ContainersEqual(a.replica_ids(),
- b.replica_ids());
- });
+ return absl::c_equal(replica_groups(), casted_other.replica_groups(),
+ [](const ReplicaGroup& a, const ReplicaGroup& b) {
+ return absl::c_equal(a.replica_ids(), b.replica_ids());
+ });
}
HloAllReduceInstruction::HloAllReduceInstruction(
@@ -452,11 +451,10 @@ bool HloCollectivePermuteInstruction::IdenticalSlowPath(
/*eq_computations*/) const {
const auto& casted_other =
static_cast<const HloCollectivePermuteInstruction&>(other);
- return ContainersEqual(
- source_target_pairs(), casted_other.source_target_pairs(),
- [](const std::pair<int64, int64>& a, const std::pair<int64, int64>& b) {
- return a == b;
- });
+ return absl::c_equal(source_target_pairs(),
+ casted_other.source_target_pairs(),
+ [](const std::pair<int64, int64>& a,
+ const std::pair<int64, int64>& b) { return a == b; });
}
std::unique_ptr<HloInstruction>
diff --git a/tensorflow/compiler/xla/service/hlo_verifier.cc b/tensorflow/compiler/xla/service/hlo_verifier.cc
index 3f3cb2fa54..744cd64bc5 100644
--- a/tensorflow/compiler/xla/service/hlo_verifier.cc
+++ b/tensorflow/compiler/xla/service/hlo_verifier.cc
@@ -1068,9 +1068,9 @@ StatusOr<bool> HloVerifier::Run(HloModule* module) {
TF_RET_CHECK(instruction->parent() == computation);
if (instruction->opcode() == HloOpcode::kFusion) {
TF_RETURN_IF_ERROR(CheckFusionInstruction(instruction));
- TF_RET_CHECK(
- ContainersEqual(instruction->called_computations(),
- {instruction->fused_instructions_computation()}))
+ TF_RET_CHECK(instruction->called_computations() ==
+ absl::Span<HloComputation* const>(
+ {instruction->fused_instructions_computation()}))
<< "Fusion HLO calls computations other than the "
"fused_instructions_computation: "
<< instruction->ToString()
diff --git a/tensorflow/compiler/xla/shape_util.cc b/tensorflow/compiler/xla/shape_util.cc
index 763fd3525b..6d016abfde 100644
--- a/tensorflow/compiler/xla/shape_util.cc
+++ b/tensorflow/compiler/xla/shape_util.cc
@@ -95,11 +95,11 @@ bool CompareShapes(const Shape& lhs, const Shape& rhs, bool compare_layouts,
}
if (ShapeUtil::IsTuple(lhs)) {
- return ContainersEqual(lhs.tuple_shapes(), rhs.tuple_shapes(),
- [=](const Shape& l, const Shape& r) {
- return CompareShapes(l, r, compare_layouts,
- ignore_fp_precision);
- });
+ return absl::c_equal(lhs.tuple_shapes(), rhs.tuple_shapes(),
+ [=](const Shape& l, const Shape& r) {
+ return CompareShapes(l, r, compare_layouts,
+ ignore_fp_precision);
+ });
} else if (!ShapeUtil::IsArray(lhs)) {
// Non-tuple, non-array tupes such as opaque and token types are trivially
// the same.
@@ -111,13 +111,13 @@ bool CompareShapes(const Shape& lhs, const Shape& rhs, bool compare_layouts,
return false;
}
if (LayoutUtil::IsDenseArray(lhs)) {
- if (!ContainersEqual(LayoutUtil::MinorToMajor(lhs),
- LayoutUtil::MinorToMajor(rhs))) {
+ if (!absl::c_equal(LayoutUtil::MinorToMajor(lhs),
+ LayoutUtil::MinorToMajor(rhs))) {
VLOG(3) << "CompareShapes: lhs layout != rhs layout";
return false;
}
- if (!ContainersEqual(lhs.layout().padded_dimensions(),
- rhs.layout().padded_dimensions())) {
+ if (!absl::c_equal(lhs.layout().padded_dimensions(),
+ rhs.layout().padded_dimensions())) {
VLOG(3)
<< "CompareShapes: lhs padded_dimensions != rhs padded_dimensions";
return false;
@@ -662,7 +662,7 @@ StatusOr<Shape> ParseShapeStringInternal(absl::string_view* s) {
const Shape& rhs) {
CHECK(ShapeUtil::IsArray(lhs));
CHECK(ShapeUtil::IsArray(rhs));
- return ContainersEqual(lhs.dimensions(), rhs.dimensions());
+ return absl::c_equal(lhs.dimensions(), rhs.dimensions());
}
/* static */ bool ShapeUtil::Compatible(const Shape& lhs, const Shape& rhs) {
@@ -676,8 +676,8 @@ StatusOr<Shape> ParseShapeStringInternal(absl::string_view* s) {
return IsArray(rhs) && SameDimensions(lhs, rhs);
} else if (lhs.element_type() == TUPLE) {
return rhs.element_type() == TUPLE &&
- ContainersEqual(lhs.tuple_shapes(), rhs.tuple_shapes(),
- CompatibleIgnoringElementType);
+ absl::c_equal(lhs.tuple_shapes(), rhs.tuple_shapes(),
+ CompatibleIgnoringElementType);
} else {
// Opaque, token, etc types are vacuously compatible.
return lhs.element_type() == rhs.element_type();
@@ -691,8 +691,8 @@ StatusOr<Shape> ParseShapeStringInternal(absl::string_view* s) {
CompatibleIgnoringElementType(lhs, rhs);
} else if (lhs.element_type() == TUPLE) {
return rhs.element_type() == TUPLE &&
- ContainersEqual(lhs.tuple_shapes(), rhs.tuple_shapes(),
- CompatibleIgnoringFpPrecision);
+ absl::c_equal(lhs.tuple_shapes(), rhs.tuple_shapes(),
+ CompatibleIgnoringFpPrecision);
} else {
// Opaque, token, etc types are vacuously compatible.
return lhs.element_type() == rhs.element_type();
@@ -1286,7 +1286,7 @@ ShapeUtil::DimensionsUnmodifiedByReshape(const Shape& input_shape,
// apply(input_dimensions, I) =
// apply((dimension_mapping * output_dimensions), I)
// input_dimensions = dimension_mapping * output_dimensions
- return ContainersEqual(
+ return absl::c_equal(
ComposePermutations(dimension_mapping,
AsInt64Slice(output_shape.layout().minor_to_major())),
input_shape.layout().minor_to_major());
diff --git a/tensorflow/compiler/xla/util.h b/tensorflow/compiler/xla/util.h
index 35a3c7db32..c8b48c5ab4 100644
--- a/tensorflow/compiler/xla/util.h
+++ b/tensorflow/compiler/xla/util.h
@@ -162,33 +162,6 @@ static inline tensorflow::gtl::ArraySlice<uint64> AsUInt64Slice(
reinterpret_cast<const uint64*>(slice.data()), slice.size());
}
-// Compares two containers for equality. Returns true iff the two containers
-// have the same size and all their elements compare equal using their
-// operator==. Like std::equal, but forces size equality.
-template <typename Container1T, typename Container2T>
-bool ContainersEqual(const Container1T& c1, const Container2T& c2) {
- return ((c1.size() == c2.size()) &&
- std::equal(std::begin(c1), std::end(c1), std::begin(c2)));
-}
-
-template <typename Container1T,
- typename ElementType = typename Container1T::value_type>
-bool ContainersEqual(const Container1T& c1,
- std::initializer_list<ElementType> il) {
- tensorflow::gtl::ArraySlice<ElementType> c2{il};
- return ContainersEqual(c1, c2);
-}
-
-// Compares two containers for equality. Returns true iff the two containers
-// have the same size and all their elements compare equal using the predicate
-// p. Like std::equal, but forces size equality.
-template <typename Container1T, typename Container2T, class PredicateT>
-bool ContainersEqual(const Container1T& c1, const Container2T& c2,
- PredicateT p) {
- return ((c1.size() == c2.size()) &&
- std::equal(std::begin(c1), std::end(c1), std::begin(c2), p));
-}
-
// Performs a copy of count values from src to dest, using different strides for
// source and destination. The source starting index is src_base, while the
// destination one is dest_base.
diff --git a/tensorflow/compiler/xla/util_test.cc b/tensorflow/compiler/xla/util_test.cc
index 288479c893..50a3c545fb 100644
--- a/tensorflow/compiler/xla/util_test.cc
+++ b/tensorflow/compiler/xla/util_test.cc
@@ -37,45 +37,6 @@ TEST(UtilTest, ReindentsDifferentNumberOfLeadingSpacesUniformly) {
EXPECT_EQ(want, got);
}
-// Some smoke tests for ContainersEqual. Keeping it simple since these are just
-// basic wrappers around std::equal.
-TEST(UtilTest, ContainersEqualDefault) {
- std::vector<int> c1 = {1, 2, 3, 4};
- std::vector<int> c2 = {1, 2, 3};
- std::vector<int> c3 = {};
- std::vector<int> c4 = {1, 2, 3, 4};
- std::vector<int> c5 = {1, 2, 3, 4, 5};
- std::vector<int> c6 = {1, 3, 4, 5};
-
- EXPECT_TRUE(ContainersEqual(c1, c4));
- EXPECT_TRUE(ContainersEqual(c4, c1));
- EXPECT_FALSE(ContainersEqual(c1, c2));
- EXPECT_FALSE(ContainersEqual(c2, c1));
- EXPECT_FALSE(ContainersEqual(c1, c3));
- EXPECT_FALSE(ContainersEqual(c3, c1));
- EXPECT_FALSE(ContainersEqual(c1, c5));
- EXPECT_FALSE(ContainersEqual(c5, c1));
- EXPECT_FALSE(ContainersEqual(c1, c6));
- EXPECT_FALSE(ContainersEqual(c6, c1));
-}
-
-TEST(UtilTest, ContainersEqualPredicate) {
- std::vector<int> c1 = {1, 2, 3, 4};
- std::vector<int> c2 = {10, 20, 30, 40};
-
- EXPECT_TRUE(ContainersEqual(
- c1, c2, [](const int& i1, const int& i2) { return i1 < i2; }));
- EXPECT_FALSE(ContainersEqual(
- c1, c2, [](const int& i1, const int& i2) { return i1 > i2; }));
-}
-
-TEST(UtilTest, ContainersEqualDifferentContainerTypes) {
- std::vector<int> c1 = {1, 2, 3, 4};
- std::list<int> c2 = {1, 2, 3, 4};
-
- EXPECT_TRUE(ContainersEqual(c1, c2));
-}
-
TEST(UtilTest, HumanReadableNumFlopsExample) {
ASSERT_EQ("1.00GFLOP/s", HumanReadableNumFlops(1e9, 1e9));
}
@@ -117,8 +78,8 @@ TEST(UtilTest, CommonFactors) {
/*.expected =*/{{0, 0}, {0, 1}, {2, 2}, {3, 2}, {4, 3}, {4, 4}}},
};
for (const auto& test_case : test_cases) {
- EXPECT_TRUE(ContainersEqual(test_case.expected,
- CommonFactors(test_case.a, test_case.b)));
+ EXPECT_TRUE(absl::c_equal(test_case.expected,
+ CommonFactors(test_case.a, test_case.b)));
}
}