summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--absl/container/btree_test.cc25
-rw-r--r--absl/container/internal/btree.h41
2 files changed, 58 insertions, 8 deletions
diff --git a/absl/container/btree_test.cc b/absl/container/btree_test.cc
index 05bbcf2f..404ccde7 100644
--- a/absl/container/btree_test.cc
+++ b/absl/container/btree_test.cc
@@ -3358,6 +3358,31 @@ TEST(Btree, DereferencingEndIterator) {
EXPECT_DEATH(*set.end(), R"regex(Dereferencing end\(\) iterator)regex");
}
+TEST(Btree, InvalidIteratorComparison) {
+ if (!IsAssertEnabled()) GTEST_SKIP() << "Assertions not enabled.";
+
+ absl::btree_set<int> set1, set2;
+ for (int i = 0; i < 1000; ++i) {
+ set1.insert(i);
+ set2.insert(i);
+ }
+
+ constexpr const char *kValueInitDeathMessage =
+ "Comparing default-constructed iterator with .*non-default-constructed "
+ "iterator";
+ typename absl::btree_set<int>::iterator iter1, iter2;
+ EXPECT_EQ(iter1, iter2);
+ EXPECT_DEATH(void(set1.begin() == iter1), kValueInitDeathMessage);
+ EXPECT_DEATH(void(iter1 == set1.begin()), kValueInitDeathMessage);
+
+ constexpr const char *kDifferentContainerDeathMessage =
+ "Comparing iterators from different containers";
+ iter1 = set1.begin();
+ iter2 = set2.begin();
+ EXPECT_DEATH(void(iter1 == iter2), kDifferentContainerDeathMessage);
+ EXPECT_DEATH(void(iter2 == iter1), kDifferentContainerDeathMessage);
+}
+
} // namespace
} // namespace container_internal
ABSL_NAMESPACE_END
diff --git a/absl/container/internal/btree.h b/absl/container/internal/btree.h
index 2e21dc66..ab75afb4 100644
--- a/absl/container/internal/btree.h
+++ b/absl/container/internal/btree.h
@@ -1017,6 +1017,17 @@ class btree_node {
friend struct btree_access;
};
+template <typename Node>
+bool AreNodesFromSameContainer(const Node *node_a, const Node *node_b) {
+ // If either node is null, then give up on checking whether they're from the
+ // same container. (If exactly one is null, then we'll trigger the
+ // default-constructed assert in Equals.)
+ if (node_a == nullptr || node_b == nullptr) return true;
+ while (!node_a->is_root()) node_a = node_a->parent();
+ while (!node_b->is_root()) node_b = node_b->parent();
+ return node_a == node_b;
+}
+
template <typename Node, typename Reference, typename Pointer>
class btree_iterator {
using field_type = typename Node::field_type;
@@ -1073,16 +1084,16 @@ class btree_iterator {
}
bool operator==(const iterator &other) const {
- return node_ == other.node_ && position_ == other.position_;
+ return Equals(other.node_, other.position_);
}
bool operator==(const const_iterator &other) const {
- return node_ == other.node_ && position_ == other.position_;
+ return Equals(other.node_, other.position_);
}
bool operator!=(const iterator &other) const {
- return node_ != other.node_ || position_ != other.position_;
+ return !Equals(other.node_, other.position_);
}
bool operator!=(const const_iterator &other) const {
- return node_ != other.node_ || position_ != other.position_;
+ return !Equals(other.node_, other.position_);
}
// Returns n such that n calls to ++other yields *this.
@@ -1161,13 +1172,27 @@ class btree_iterator {
#endif
}
+ bool Equals(const node_type *other_node, int other_position) const {
+ ABSL_HARDENING_ASSERT(((node_ == nullptr && other_node == nullptr) ||
+ (node_ != nullptr && other_node != nullptr)) &&
+ "Comparing default-constructed iterator with "
+ "non-default-constructed iterator.");
+ // Note: we use assert instead of ABSL_HARDENING_ASSERT here because this
+ // changes the complexity of Equals from O(1) to O(log(N) + log(M)) where
+ // N/M are sizes of the containers containing node_/other_node.
+ assert(AreNodesFromSameContainer(node_, other_node) &&
+ "Comparing iterators from different containers.");
+ return node_ == other_node && position_ == other_position;
+ }
+
bool IsEndIterator() const {
if (position_ != node_->finish()) return false;
- // Navigate to the rightmost node.
node_type *node = node_;
- while (!node->is_root()) node = node->parent();
- while (node->is_internal()) node = node->child(node->finish());
- return node == node_;
+ while (!node->is_root()) {
+ if (node->position() != node->parent()->finish()) return false;
+ node = node->parent();
+ }
+ return true;
}
// Returns n such that n calls to ++other yields *this.