aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar Benjamin Kramer <kramerb@google.com>2018-10-01 19:42:12 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-10-01 19:46:29 -0700
commit991f06fd50fc73285ce415d57f720994c2b2e861 (patch)
tree0b87402386aad22ec958f171bfd57f9c7c3e8571
parentbeede8525be5386451bf0098992c37416d1864db (diff)
[XLA] Migrate from gtl::FlatSet to absl::flat_hash_set
PiperOrigin-RevId: 215324035
-rw-r--r--tensorflow/compiler/jit/BUILD2
-rw-r--r--tensorflow/compiler/jit/deadness_analysis.cc10
-rw-r--r--tensorflow/compiler/jit/encapsulate_subgraphs_pass.cc7
-rw-r--r--tensorflow/compiler/jit/encapsulate_xla_computations_pass.cc10
-rw-r--r--tensorflow/compiler/jit/mark_for_compilation_pass.cc6
-rw-r--r--tensorflow/compiler/jit/partially_decluster_pass.cc7
-rw-r--r--tensorflow/compiler/jit/resource_operation_safety_analysis.cc4
-rw-r--r--tensorflow/compiler/tests/BUILD1
-rw-r--r--tensorflow/compiler/tests/randomized_tests.cc14
-rw-r--r--tensorflow/compiler/xla/client/BUILD1
-rw-r--r--tensorflow/compiler/xla/client/xla_builder.cc4
-rw-r--r--tensorflow/compiler/xla/client/xla_builder.h4
-rw-r--r--tensorflow/compiler/xla/service/BUILD27
-rw-r--r--tensorflow/compiler/xla/service/bfloat16_propagation.cc9
-rw-r--r--tensorflow/compiler/xla/service/bfloat16_propagation.h11
-rw-r--r--tensorflow/compiler/xla/service/buffer_assignment.cc48
-rw-r--r--tensorflow/compiler/xla/service/buffer_assignment.h22
-rw-r--r--tensorflow/compiler/xla/service/buffer_liveness.h4
-rw-r--r--tensorflow/compiler/xla/service/buffer_value_containers.h4
-rw-r--r--tensorflow/compiler/xla/service/call_graph.cc9
-rw-r--r--tensorflow/compiler/xla/service/call_graph.h10
-rw-r--r--tensorflow/compiler/xla/service/copy_insertion.cc6
-rw-r--r--tensorflow/compiler/xla/service/cpu/BUILD1
-rw-r--r--tensorflow/compiler/xla/service/cpu/ir_emitter.cc8
-rw-r--r--tensorflow/compiler/xla/service/cpu/tests/cpu_noalias_test.cc2
-rw-r--r--tensorflow/compiler/xla/service/gpu/BUILD3
-rw-r--r--tensorflow/compiler/xla/service/gpu/gpu_copy_insertion.cc2
-rw-r--r--tensorflow/compiler/xla/service/gpu/instruction_fusion.cc5
-rw-r--r--tensorflow/compiler/xla/service/gpu/multi_output_fusion.cc6
-rw-r--r--tensorflow/compiler/xla/service/heap_simulator.cc13
-rw-r--r--tensorflow/compiler/xla/service/heap_simulator.h6
-rw-r--r--tensorflow/compiler/xla/service/hlo_alias_analysis.cc9
-rw-r--r--tensorflow/compiler/xla/service/hlo_buffer.cc2
-rw-r--r--tensorflow/compiler/xla/service/hlo_computation.cc11
-rw-r--r--tensorflow/compiler/xla/service/hlo_computation.h2
-rw-r--r--tensorflow/compiler/xla/service/hlo_cse.cc6
-rw-r--r--tensorflow/compiler/xla/service/hlo_dataflow_analysis.cc9
-rw-r--r--tensorflow/compiler/xla/service/hlo_domain_map.cc3
-rw-r--r--tensorflow/compiler/xla/service/hlo_domain_map.h4
-rw-r--r--tensorflow/compiler/xla/service/hlo_domain_metadata.h8
-rw-r--r--tensorflow/compiler/xla/service/hlo_instruction.cc4
-rw-r--r--tensorflow/compiler/xla/service/hlo_memory_scheduler.cc7
-rw-r--r--tensorflow/compiler/xla/service/hlo_module.cc9
-rw-r--r--tensorflow/compiler/xla/service/hlo_module_group_util.cc6
-rw-r--r--tensorflow/compiler/xla/service/hlo_pass_pipeline.cc6
-rw-r--r--tensorflow/compiler/xla/service/hlo_rematerialization.cc3
-rw-r--r--tensorflow/compiler/xla/service/hlo_rematerialization.h3
-rw-r--r--tensorflow/compiler/xla/service/hlo_schedule.cc5
-rw-r--r--tensorflow/compiler/xla/service/hlo_value.cc4
-rw-r--r--tensorflow/compiler/xla/service/indexed_array_analysis.cc2
-rw-r--r--tensorflow/compiler/xla/service/layout_assignment.h7
-rw-r--r--tensorflow/compiler/xla/service/llvm_ir/BUILD1
-rw-r--r--tensorflow/compiler/xla/service/llvm_ir/alias_analysis.cc6
-rw-r--r--tensorflow/compiler/xla/service/llvm_ir/alias_analysis.h1
-rw-r--r--tensorflow/compiler/xla/service/multi_output_fusion.cc6
-rw-r--r--tensorflow/compiler/xla/service/name_uniquer.h4
-rw-r--r--tensorflow/compiler/xla/service/shape_inference.cc4
-rw-r--r--tensorflow/compiler/xla/service/shaped_buffer.cc4
-rw-r--r--tensorflow/compiler/xla/service/tuple_points_to_analysis.h1
-rw-r--r--tensorflow/compiler/xla/service/while_loop_invariant_code_motion.cc8
-rw-r--r--tensorflow/compiler/xla/service/while_loop_simplifier.cc3
-rw-r--r--tensorflow/compiler/xla/tests/BUILD2
-rw-r--r--tensorflow/compiler/xla/tests/test_utils_test.cc5
63 files changed, 235 insertions, 186 deletions
diff --git a/tensorflow/compiler/jit/BUILD b/tensorflow/compiler/jit/BUILD
index f20270931f..661b444a42 100644
--- a/tensorflow/compiler/jit/BUILD
+++ b/tensorflow/compiler/jit/BUILD
@@ -325,6 +325,7 @@ cc_library(
"//tensorflow/core:lib",
"//tensorflow/core:protos_all_cc",
"@com_google_absl//absl/container:flat_hash_map",
+ "@com_google_absl//absl/container:flat_hash_set",
"@com_google_absl//absl/memory",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/types:optional",
@@ -407,6 +408,7 @@ cc_library(
"//tensorflow/core/kernels:bounds_check",
"@com_google_absl//absl/algorithm:container",
"@com_google_absl//absl/container:flat_hash_map",
+ "@com_google_absl//absl/container:flat_hash_set",
"@com_google_absl//absl/memory",
"@com_google_absl//absl/strings",
],
diff --git a/tensorflow/compiler/jit/deadness_analysis.cc b/tensorflow/compiler/jit/deadness_analysis.cc
index e63d4b7792..e0b9932d80 100644
--- a/tensorflow/compiler/jit/deadness_analysis.cc
+++ b/tensorflow/compiler/jit/deadness_analysis.cc
@@ -16,11 +16,11 @@ limitations under the License.
#include "tensorflow/compiler/jit/deadness_analysis.h"
#include "absl/algorithm/container.h"
#include "absl/container/flat_hash_map.h"
+#include "absl/container/flat_hash_set.h"
#include "absl/strings/str_join.h"
#include "tensorflow/compiler/jit/deadness_analysis_internal.h"
#include "tensorflow/core/graph/algorithm.h"
#include "tensorflow/core/graph/tensor_id.h"
-#include "tensorflow/core/lib/gtl/flatset.h"
#include "tensorflow/core/lib/hash/hash.h"
// ALGORITHM OVERVIEW
@@ -298,7 +298,7 @@ class SymbolPredicate : public Predicate {
template <typename FunctionTy>
/*static*/ void Predicate::Visit(Predicate* p, const FunctionTy& func) {
- gtl::FlatSet<Predicate*> visited;
+ absl::flat_hash_set<Predicate*> visited;
std::vector<Predicate*> stack;
stack.push_back(p);
@@ -467,7 +467,7 @@ Predicate* PredicateFactory::MakeAndOrImpl(
is_and ? Predicate::Kind::kAnd : Predicate::Kind::kOr;
Predicate::Kind other_pred_kind =
is_and ? Predicate::Kind::kOr : Predicate::Kind::kAnd;
- gtl::FlatSet<Predicate*> simplified_ops_set;
+ absl::flat_hash_set<Predicate*> simplified_ops_set;
std::vector<Predicate*> simplified_ops;
for (Predicate* op : operands) {
// Simplify A&A => A and A|A => A.
@@ -492,7 +492,7 @@ Predicate* PredicateFactory::MakeAndOrImpl(
}
// Simplify "A&~A=>False" and "A|~A=>True".
- gtl::FlatSet<Predicate*> negated_ops;
+ absl::flat_hash_set<Predicate*> negated_ops;
for (Predicate* op : simplified_ops) {
if (op->kind() == Predicate::Kind::kNot) {
negated_ops.insert(dynamic_cast<NotPredicate&>(*op).operand());
@@ -512,7 +512,7 @@ Predicate* PredicateFactory::MakeAndOrImpl(
//
// First find any predicates contained in all subops.
std::vector<Predicate*> common_inner_operands;
- gtl::FlatSet<Predicate*> common_inner_operands_set;
+ absl::flat_hash_set<Predicate*> common_inner_operands_set;
for (Predicate* op : simplified_ops) {
if (op->kind() != other_pred_kind) {
common_inner_operands.clear();
diff --git a/tensorflow/compiler/jit/encapsulate_subgraphs_pass.cc b/tensorflow/compiler/jit/encapsulate_subgraphs_pass.cc
index d165341f21..da27f837e8 100644
--- a/tensorflow/compiler/jit/encapsulate_subgraphs_pass.cc
+++ b/tensorflow/compiler/jit/encapsulate_subgraphs_pass.cc
@@ -22,6 +22,7 @@ limitations under the License.
#include <unordered_map>
#include <vector>
+#include "absl/container/flat_hash_set.h"
#include "absl/strings/match.h"
#include "absl/strings/str_cat.h"
#include "tensorflow/compiler/jit/graphcycles/graphcycles.h"
@@ -44,7 +45,6 @@ limitations under the License.
#include "tensorflow/core/graph/graph.h"
#include "tensorflow/core/graph/graph_def_builder.h"
#include "tensorflow/core/graph/tensor_id.h"
-#include "tensorflow/core/lib/gtl/flatset.h"
#include "tensorflow/core/lib/gtl/map_util.h"
#include "tensorflow/core/lib/hash/hash.h"
#include "tensorflow/core/public/session_options.h"
@@ -78,7 +78,8 @@ void SortControlInputs(GraphDef* gdef) {
namespace {
bool AreAllParentsGuaranteedConst(
- const Node& n, const gtl::FlatSet<const Node*>& runtime_const_nodes) {
+ const Node& n,
+ const absl::flat_hash_set<const Node*>& runtime_const_nodes) {
if (n.type_string() == "GuaranteeConst") {
// If the current node is itself a cast-to-const, no need
// to look at the incoming edges.
@@ -101,7 +102,7 @@ bool AreAllParentsGuaranteedConst(
void MarkGuaranteedConstants(
const Graph& graph,
const std::vector<std::pair<const Node*, Node*>>& src_arg_pairs) {
- gtl::FlatSet<const Node*> guaranteed_const_nodes;
+ absl::flat_hash_set<const Node*> guaranteed_const_nodes;
std::vector<const Node*> srcs;
srcs.reserve(src_arg_pairs.size());
for (const auto& src_arg : src_arg_pairs) {
diff --git a/tensorflow/compiler/jit/encapsulate_xla_computations_pass.cc b/tensorflow/compiler/jit/encapsulate_xla_computations_pass.cc
index 755c364c62..2ce6fa73fc 100644
--- a/tensorflow/compiler/jit/encapsulate_xla_computations_pass.cc
+++ b/tensorflow/compiler/jit/encapsulate_xla_computations_pass.cc
@@ -15,13 +15,13 @@ limitations under the License.
#include "tensorflow/compiler/jit/encapsulate_xla_computations_pass.h"
+#include "absl/container/flat_hash_set.h"
#include "absl/memory/memory.h"
#include "absl/strings/str_cat.h"
#include "tensorflow/compiler/jit/encapsulate_subgraphs_pass.h"
#include "tensorflow/compiler/tf2xla/dump_graph.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/core/framework/node_def.pb.h"
-#include "tensorflow/core/lib/gtl/flatset.h"
#include "tensorflow/core/lib/hash/hash.h"
#include "tensorflow/core/lib/strings/proto_serialization.h"
#include "tensorflow/core/lib/strings/str_util.h"
@@ -62,7 +62,7 @@ DataType EdgeType(const Edge* edge) {
}
// Adds the control inputs of `node` to `*deps`.
-void AddControlInputs(const Node& node, gtl::FlatSet<Node*>* deps) {
+void AddControlInputs(const Node& node, absl::flat_hash_set<Node*>* deps) {
for (const Edge* edge : node.in_edges()) {
if (edge->IsControlEdge()) {
deps->insert(edge->src());
@@ -71,7 +71,7 @@ void AddControlInputs(const Node& node, gtl::FlatSet<Node*>* deps) {
}
// Adds the control outputs of `node` to `*deps`.
-void AddControlOutputs(const Node& node, gtl::FlatSet<Node*>* deps) {
+void AddControlOutputs(const Node& node, absl::flat_hash_set<Node*>* deps) {
for (const Edge* edge : node.out_edges()) {
if (edge->IsControlEdge()) {
deps->insert(edge->dst());
@@ -246,7 +246,7 @@ Status RewriteSubgraph(const std::vector<OutputTensor>& arg_source_tensors,
// Data and control inputs to the new XlaLaunch node.
std::vector<std::pair<Node*, int>> data_inputs(num_inputs);
- gtl::FlatSet<Node*> control_inputs;
+ absl::flat_hash_set<Node*> control_inputs;
DataTypeVector arg_types(num_args);
AddControlInputs(*launch, &control_inputs);
@@ -266,7 +266,7 @@ Status RewriteSubgraph(const std::vector<OutputTensor>& arg_source_tensors,
// Outputs.
const int num_outputs = launch->output_types().size();
- gtl::FlatSet<Node*> control_outputs;
+ absl::flat_hash_set<Node*> control_outputs;
std::vector<std::vector<std::pair<Node*, int>>> data_outputs(num_outputs);
DataTypeVector output_types(num_outputs);
diff --git a/tensorflow/compiler/jit/mark_for_compilation_pass.cc b/tensorflow/compiler/jit/mark_for_compilation_pass.cc
index 133d982360..4f0c370e65 100644
--- a/tensorflow/compiler/jit/mark_for_compilation_pass.cc
+++ b/tensorflow/compiler/jit/mark_for_compilation_pass.cc
@@ -21,6 +21,7 @@ limitations under the License.
#include <unordered_map>
#include <unordered_set>
+#include "absl/container/flat_hash_set.h"
#include "tensorflow/compiler/jit/deadness_analysis.h"
#include "tensorflow/compiler/jit/defs.h"
#include "tensorflow/compiler/jit/graphcycles/graphcycles.h"
@@ -42,7 +43,6 @@ limitations under the License.
#include "tensorflow/core/graph/control_flow.h"
#include "tensorflow/core/kernels/bounds_check.h"
#include "tensorflow/core/lib/gtl/cleanup.h"
-#include "tensorflow/core/lib/gtl/flatset.h"
#include "tensorflow/core/lib/strings/stringprintf.h"
#include "tensorflow/core/public/version.h"
@@ -371,7 +371,7 @@ bool IsXlaFusable(const NodeDef& node) {
Status FindCompilationCandidates(
const Graph& graph, FunctionLibraryDefinition* flib_def, Env* env,
const std::function<bool(const Node*, const DeviceType&)>& is_compilable_fn,
- OrderedNodeSet* candidates, gtl::FlatSet<Node*>* isolated_nodes) {
+ OrderedNodeSet* candidates, absl::flat_hash_set<Node*>* isolated_nodes) {
OptimizerOptions opts;
std::unique_ptr<ProcessFunctionLibraryRuntime> pflr(
new ProcessFunctionLibraryRuntime(nullptr, env, TF_GRAPH_DEF_VERSION,
@@ -849,7 +849,7 @@ Status MarkForCompilationPass::RunImpl(
Graph* graph = options.graph->get();
OrderedNodeSet compilation_candidates;
- gtl::FlatSet<Node*> isolated_nodes;
+ absl::flat_hash_set<Node*> isolated_nodes;
TF_RETURN_IF_ERROR(FindCompilationCandidates(
*graph, options.flib_def,
(options.session_options != nullptr) ? options.session_options->env
diff --git a/tensorflow/compiler/jit/partially_decluster_pass.cc b/tensorflow/compiler/jit/partially_decluster_pass.cc
index 10fc9e85d9..b1f9e9088f 100644
--- a/tensorflow/compiler/jit/partially_decluster_pass.cc
+++ b/tensorflow/compiler/jit/partially_decluster_pass.cc
@@ -15,17 +15,18 @@ limitations under the License.
#include "tensorflow/compiler/jit/partially_decluster_pass.h"
#include "absl/algorithm/container.h"
+#include "absl/container/flat_hash_set.h"
#include "absl/strings/str_cat.h"
#include "tensorflow/compiler/jit/xla_cluster_util.h"
#include "tensorflow/compiler/tf2xla/const_analysis.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
#include "tensorflow/core/framework/memory_types.h"
#include "tensorflow/core/framework/node_def.pb.h"
-#include "tensorflow/core/lib/gtl/flatset.h"
namespace tensorflow {
namespace {
-Status FindNodesToDecluster(const Graph& graph, gtl::FlatSet<Node*>* result,
+Status FindNodesToDecluster(const Graph& graph,
+ absl::flat_hash_set<Node*>* result,
absl::Span<Node* const> post_order) {
// Find nodes that have at least one user outside their cluster that expects
// hostmem output. These nodes should be cloned to outside the cluster to
@@ -171,7 +172,7 @@ Status PartiallyDeclusterToRemoveDeviceToHostCopies(Graph* graph) {
GetPostOrder(*graph, &post_order, /*stable_comparator=*/NodeComparatorName(),
/*edge_filter=*/NotBackedge);
- gtl::FlatSet<Node*> nodes_to_partially_decluster;
+ absl::flat_hash_set<Node*> nodes_to_partially_decluster;
TF_RETURN_IF_ERROR(
FindNodesToDecluster(*graph, &nodes_to_partially_decluster, post_order));
diff --git a/tensorflow/compiler/jit/resource_operation_safety_analysis.cc b/tensorflow/compiler/jit/resource_operation_safety_analysis.cc
index 657bb409db..e039d46ec8 100644
--- a/tensorflow/compiler/jit/resource_operation_safety_analysis.cc
+++ b/tensorflow/compiler/jit/resource_operation_safety_analysis.cc
@@ -82,6 +82,7 @@ limitations under the License.
#include "tensorflow/compiler/jit/resource_operation_safety_analysis.h"
+#include "absl/container/flat_hash_set.h"
#include "absl/memory/memory.h"
#include "absl/strings/str_join.h"
#include "absl/types/optional.h"
@@ -89,7 +90,6 @@ limitations under the License.
#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/graph/algorithm.h"
#include "tensorflow/core/graph/tensor_id.h"
-#include "tensorflow/core/lib/gtl/flatset.h"
#include "tensorflow/core/lib/hash/hash.h"
#include "tensorflow/core/util/ptr_util.h"
@@ -176,7 +176,7 @@ string ResourceOpToString(const ResourceOp& resource_op) {
// point.
class ResourceOpSet {
private:
- using Impl = gtl::FlatSet<ResourceOp>;
+ using Impl = absl::flat_hash_set<ResourceOp>;
public:
ResourceOpSet() = default;
diff --git a/tensorflow/compiler/tests/BUILD b/tensorflow/compiler/tests/BUILD
index 3cf74fa788..822fedf121 100644
--- a/tensorflow/compiler/tests/BUILD
+++ b/tensorflow/compiler/tests/BUILD
@@ -1105,6 +1105,7 @@ cc_library(
"//tensorflow/core:test",
"//tensorflow/core:testlib",
"//tensorflow/core/kernels:ops_util",
+ "@com_google_absl//absl/container:flat_hash_set",
"@com_google_absl//absl/strings",
],
)
diff --git a/tensorflow/compiler/tests/randomized_tests.cc b/tensorflow/compiler/tests/randomized_tests.cc
index bddda6f302..7a96f4c25c 100644
--- a/tensorflow/compiler/tests/randomized_tests.cc
+++ b/tensorflow/compiler/tests/randomized_tests.cc
@@ -45,6 +45,7 @@ limitations under the License.
#include <random>
#include <unordered_map>
+#include "absl/container/flat_hash_set.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/string_view.h"
#include "tensorflow/compiler/jit/defs.h"
@@ -63,7 +64,6 @@ limitations under the License.
#include "tensorflow/core/kernels/ops_util.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/core/status_test_util.h"
-#include "tensorflow/core/lib/gtl/flatset.h"
#include "tensorflow/core/platform/test.h"
#include "tensorflow/core/public/session.h"
#include "tensorflow/core/public/session_options.h"
@@ -457,7 +457,7 @@ Tensor OpTest::RandomTensor(DataType dtype, bool needs_unique_values,
Tensor tensor(dtype, TensorShape(shape));
switch (dtype) {
case DT_FLOAT: {
- gtl::FlatSet<float> already_generated;
+ absl::flat_hash_set<float> already_generated;
std::uniform_real_distribution<float> distribution(-1.0f, 1.0f);
test::FillFn<float>(&tensor, [&](int i) -> float {
float generated;
@@ -470,7 +470,7 @@ Tensor OpTest::RandomTensor(DataType dtype, bool needs_unique_values,
break;
}
case DT_DOUBLE: {
- gtl::FlatSet<double> already_generated;
+ absl::flat_hash_set<double> already_generated;
std::uniform_real_distribution<double> distribution(-1.0, 1.0);
test::FillFn<double>(&tensor, [&](int i) -> double {
double generated;
@@ -483,7 +483,7 @@ Tensor OpTest::RandomTensor(DataType dtype, bool needs_unique_values,
break;
}
case DT_COMPLEX64: {
- gtl::FlatSet<std::pair<float, float>> already_generated;
+ absl::flat_hash_set<std::pair<float, float>> already_generated;
std::uniform_real_distribution<float> distribution(-1.0f, 1.0f);
test::FillFn<complex64>(&tensor, [&](int i) {
complex64 generated;
@@ -500,7 +500,7 @@ Tensor OpTest::RandomTensor(DataType dtype, bool needs_unique_values,
break;
}
case DT_INT32: {
- gtl::FlatSet<int32> already_generated;
+ absl::flat_hash_set<int32> already_generated;
std::uniform_int_distribution<int32> distribution(-(1 << 20), 1 << 20);
test::FillFn<int32>(&tensor, [&](int i) -> int32 {
int32 generated;
@@ -513,7 +513,7 @@ Tensor OpTest::RandomTensor(DataType dtype, bool needs_unique_values,
break;
}
case DT_INT64: {
- gtl::FlatSet<int64> already_generated;
+ absl::flat_hash_set<int64> already_generated;
std::uniform_int_distribution<int64> distribution(-(1LL << 40),
1LL << 40);
test::FillFn<int64>(&tensor, [&](int i) -> int64 {
@@ -527,7 +527,7 @@ Tensor OpTest::RandomTensor(DataType dtype, bool needs_unique_values,
break;
}
case DT_BOOL: {
- gtl::FlatSet<bool> already_generated;
+ absl::flat_hash_set<bool> already_generated;
std::bernoulli_distribution distribution;
test::FillFn<bool>(&tensor, [&](int i) -> bool {
bool generated;
diff --git a/tensorflow/compiler/xla/client/BUILD b/tensorflow/compiler/xla/client/BUILD
index 1191cff109..dc097f3696 100644
--- a/tensorflow/compiler/xla/client/BUILD
+++ b/tensorflow/compiler/xla/client/BUILD
@@ -221,6 +221,7 @@ cc_library(
"//tensorflow/core:lib",
"@com_google_absl//absl/algorithm:container",
"@com_google_absl//absl/container:flat_hash_map",
+ "@com_google_absl//absl/container:flat_hash_set",
"@com_google_absl//absl/memory",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/types:span",
diff --git a/tensorflow/compiler/xla/client/xla_builder.cc b/tensorflow/compiler/xla/client/xla_builder.cc
index 5277de6a85..e0ec91dba1 100644
--- a/tensorflow/compiler/xla/client/xla_builder.cc
+++ b/tensorflow/compiler/xla/client/xla_builder.cc
@@ -22,6 +22,7 @@ limitations under the License.
#include <utility>
#include "absl/algorithm/container.h"
+#include "absl/container/flat_hash_set.h"
#include "absl/memory/memory.h"
#include "absl/strings/match.h"
#include "absl/strings/str_cat.h"
@@ -33,7 +34,6 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
#include "tensorflow/compiler/xla/service/shape_inference.h"
#include "tensorflow/compiler/xla/util.h"
-#include "tensorflow/core/lib/gtl/flatset.h"
#include "tensorflow/core/platform/mutex.h"
namespace xla {
@@ -2290,7 +2290,7 @@ StatusOr<XlaComputation> XlaBuilder::BuildConstantSubGraph(
// also a valid dependency order). The related ops will be added to the
// subgraph in the same order.
std::set<int64> related_ops;
- tensorflow::gtl::FlatSet<int64> related_calls; // Related computations.
+ absl::flat_hash_set<int64> related_calls; // Related computations.
std::queue<int64> worklist;
worklist.push(root->id());
related_ops.insert(root->id());
diff --git a/tensorflow/compiler/xla/client/xla_builder.h b/tensorflow/compiler/xla/client/xla_builder.h
index b7295e8a53..cd0d5ca5d3 100644
--- a/tensorflow/compiler/xla/client/xla_builder.h
+++ b/tensorflow/compiler/xla/client/xla_builder.h
@@ -22,6 +22,7 @@ limitations under the License.
#include <utility>
#include "absl/container/flat_hash_map.h"
+#include "absl/container/flat_hash_set.h"
#include "absl/strings/string_view.h"
#include "absl/types/span.h"
#include "tensorflow/compiler/xla/client/padding.h"
@@ -35,7 +36,6 @@ limitations under the License.
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
-#include "tensorflow/core/lib/gtl/flatset.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/stacktrace.h"
#include "tensorflow/core/platform/types.h"
@@ -1035,7 +1035,7 @@ class XlaBuilder {
std::map<int64, HloComputationProto> embedded_;
// The unique parameter numbers.
- tensorflow::gtl::FlatSet<int64> parameter_numbers_;
+ absl::flat_hash_set<int64> parameter_numbers_;
// The metadata to attach to each op. This is structured as a "modal"-like
// operation, in order to simplify client code (and not sprinkle this metadata
diff --git a/tensorflow/compiler/xla/service/BUILD b/tensorflow/compiler/xla/service/BUILD
index 8da6364786..13803f5ebe 100644
--- a/tensorflow/compiler/xla/service/BUILD
+++ b/tensorflow/compiler/xla/service/BUILD
@@ -147,6 +147,7 @@ cc_library(
"//tensorflow/compiler/xla:util",
"//tensorflow/core:lib",
"@com_google_absl//absl/container:flat_hash_map",
+ "@com_google_absl//absl/container:flat_hash_set",
],
)
@@ -183,6 +184,7 @@ cc_library(
"//tensorflow/compiler/xla:xla_data_proto",
"//tensorflow/core:lib",
"@com_google_absl//absl/algorithm:container",
+ "@com_google_absl//absl/container:flat_hash_set",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/strings:str_format",
"@com_google_absl//absl/types:span",
@@ -336,6 +338,7 @@ cc_library(
"//tensorflow/core:lib_internal",
"@com_google_absl//absl/algorithm:container",
"@com_google_absl//absl/container:flat_hash_map",
+ "@com_google_absl//absl/container:flat_hash_set",
"@com_google_absl//absl/container:inlined_vector",
"@com_google_absl//absl/memory",
"@com_google_absl//absl/strings",
@@ -490,6 +493,7 @@ cc_library(
"//tensorflow/compiler/xla:util",
"//tensorflow/core:lib",
"@com_google_absl//absl/container:flat_hash_map",
+ "@com_google_absl//absl/container:flat_hash_set",
"@com_google_absl//absl/memory",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/strings:str_format",
@@ -781,6 +785,7 @@ cc_library(
"//tensorflow/compiler/xla:xla_data_proto",
"//tensorflow/core:lib",
"//tensorflow/core:stream_executor_no_cuda",
+ "@com_google_absl//absl/container:flat_hash_set",
"@com_google_absl//absl/memory",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/strings:str_format",
@@ -959,6 +964,7 @@ cc_library(
"//tensorflow/compiler/xla:types",
"//tensorflow/core:lib",
"@com_google_absl//absl/container:flat_hash_map",
+ "@com_google_absl//absl/container:flat_hash_set",
"@com_google_absl//absl/strings",
],
)
@@ -995,6 +1001,7 @@ cc_library(
"//tensorflow/compiler/xla:util",
"//tensorflow/core:lib",
"@com_google_absl//absl/container:flat_hash_map",
+ "@com_google_absl//absl/container:flat_hash_set",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/strings:str_format",
],
@@ -1043,6 +1050,7 @@ cc_library(
"//tensorflow/core:lib",
"//tensorflow/core:lib_internal",
"@com_google_absl//absl/container:flat_hash_map",
+ "@com_google_absl//absl/container:flat_hash_set",
"@com_google_absl//absl/memory",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/strings:str_format",
@@ -1136,6 +1144,7 @@ cc_library(
"//tensorflow/compiler/xla:util",
"//tensorflow/core:lib",
"@com_google_absl//absl/container:flat_hash_map",
+ "@com_google_absl//absl/container:flat_hash_set",
"@com_google_absl//absl/memory",
],
)
@@ -1230,6 +1239,7 @@ cc_library(
"//tensorflow/compiler/xla:util",
"//tensorflow/core:lib",
"@com_google_absl//absl/container:flat_hash_map",
+ "@com_google_absl//absl/container:flat_hash_set",
"@com_google_absl//absl/memory",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/types:span",
@@ -1275,6 +1285,7 @@ cc_library(
"//tensorflow/core:lib",
"//tensorflow/core:lib_internal",
"@com_google_absl//absl/container:flat_hash_map",
+ "@com_google_absl//absl/container:flat_hash_set",
],
)
@@ -1348,6 +1359,7 @@ cc_library(
"//tensorflow/compiler/xla/service:hlo_pass",
"//tensorflow/core:lib",
"@com_google_absl//absl/container:flat_hash_map",
+ "@com_google_absl//absl/container:flat_hash_set",
"@com_google_absl//absl/strings",
],
)
@@ -1660,6 +1672,7 @@ cc_library(
"//tensorflow/compiler/xla:statusor",
"//tensorflow/core:lib",
"@com_google_absl//absl/container:flat_hash_map",
+ "@com_google_absl//absl/container:flat_hash_set",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/types:optional",
],
@@ -2064,6 +2077,7 @@ cc_library(
":logical_buffer",
"//tensorflow/core:lib",
"//tensorflow/core:lib_internal",
+ "@com_google_absl//absl/container:flat_hash_set",
],
)
@@ -2099,6 +2113,7 @@ cc_library(
"//tensorflow/compiler/xla:util",
"//tensorflow/compiler/xla:xla_data_proto",
"//tensorflow/core:lib",
+ "@com_google_absl//absl/container:flat_hash_set",
"@com_google_absl//absl/memory",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/types:span",
@@ -2120,6 +2135,7 @@ cc_library(
"//tensorflow/compiler/xla:util",
"//tensorflow/compiler/xla:xla_data_proto",
"//tensorflow/core:lib",
+ "@com_google_absl//absl/container:flat_hash_set",
"@com_google_absl//absl/container:inlined_vector",
"@com_google_absl//absl/memory",
"@com_google_absl//absl/strings",
@@ -2203,6 +2219,7 @@ cc_library(
"//tensorflow/compiler/xla:util",
"//tensorflow/compiler/xla:xla_data_proto",
"//tensorflow/core:lib",
+ "@com_google_absl//absl/container:flat_hash_set",
"@com_google_absl//absl/strings",
],
)
@@ -2225,6 +2242,7 @@ cc_library(
"//tensorflow/compiler/xla:xla_data_proto",
"//tensorflow/core:lib",
"@com_google_absl//absl/container:flat_hash_map",
+ "@com_google_absl//absl/container:flat_hash_set",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/types:span",
],
@@ -2286,6 +2304,7 @@ cc_library(
"//tensorflow/compiler/xla:xla_data_proto",
"//tensorflow/core:lib",
"@com_google_absl//absl/container:flat_hash_map",
+ "@com_google_absl//absl/container:flat_hash_set",
"@com_google_absl//absl/container:inlined_vector",
"@com_google_absl//absl/memory",
"@com_google_absl//absl/strings",
@@ -2343,6 +2362,7 @@ cc_library(
"//tensorflow/compiler/xla:xla_data_proto",
"//tensorflow/core:lib",
"@com_google_absl//absl/container:flat_hash_map",
+ "@com_google_absl//absl/container:flat_hash_set",
"@com_google_absl//absl/memory",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/strings:str_format",
@@ -2370,6 +2390,7 @@ cc_library(
"//tensorflow/compiler/xla:util",
"//tensorflow/core:lib",
"@com_google_absl//absl/container:flat_hash_map",
+ "@com_google_absl//absl/container:flat_hash_set",
"@com_google_absl//absl/strings",
],
)
@@ -2487,6 +2508,7 @@ cc_library(
"//tensorflow/core:lib",
"//tensorflow/core:lib_internal",
"@com_google_absl//absl/container:flat_hash_map",
+ "@com_google_absl//absl/container:flat_hash_set",
"@com_google_absl//absl/container:inlined_vector",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/strings:str_format",
@@ -2616,6 +2638,7 @@ cc_library(
"//tensorflow/compiler/xla:util",
"//tensorflow/core:lib",
"@com_google_absl//absl/container:flat_hash_map",
+ "@com_google_absl//absl/container:flat_hash_set",
"@com_google_absl//absl/memory",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/strings:str_format",
@@ -2655,6 +2678,7 @@ cc_library(
"//tensorflow/compiler/xla:types",
"//tensorflow/compiler/xla:xla_data_proto",
"//tensorflow/core:lib",
+ "@com_google_absl//absl/container:flat_hash_set",
"@com_google_absl//absl/container:inlined_vector",
],
)
@@ -2730,6 +2754,7 @@ cc_library(
"//tensorflow/compiler/xla:util",
"//tensorflow/core:lib",
"@com_google_absl//absl/container:flat_hash_map",
+ "@com_google_absl//absl/container:flat_hash_set",
"@com_google_absl//absl/memory",
],
)
@@ -3300,6 +3325,7 @@ cc_library(
"//tensorflow/core:lib",
"@com_google_absl//absl/algorithm:container",
"@com_google_absl//absl/container:flat_hash_map",
+ "@com_google_absl//absl/container:flat_hash_set",
"@com_google_absl//absl/container:inlined_vector",
],
)
@@ -3387,6 +3413,7 @@ cc_library(
"//tensorflow/core:ptr_util",
"@com_google_absl//absl/algorithm:container",
"@com_google_absl//absl/container:flat_hash_map",
+ "@com_google_absl//absl/container:flat_hash_set",
"@com_google_absl//absl/container:inlined_vector",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/types:optional",
diff --git a/tensorflow/compiler/xla/service/bfloat16_propagation.cc b/tensorflow/compiler/xla/service/bfloat16_propagation.cc
index 58f78f8e24..002be9c970 100644
--- a/tensorflow/compiler/xla/service/bfloat16_propagation.cc
+++ b/tensorflow/compiler/xla/service/bfloat16_propagation.cc
@@ -15,6 +15,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/bfloat16_propagation.h"
+#include "absl/container/flat_hash_set.h"
#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/map_util.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
@@ -81,7 +82,7 @@ void BFloat16Propagation::RevertIfFusionInternalBF16Changes(
};
auto root = fusion->fused_instructions_computation()->root_instruction();
- tensorflow::gtl::FlatSet<const HloValue*> changed_root_buffers;
+ absl::flat_hash_set<const HloValue*> changed_root_buffers;
auto root_changes_it = changes_to_bf16_.find(root);
if (root_changes_it != changes_to_bf16_.end()) {
@@ -500,7 +501,7 @@ void BFloat16Propagation::AdjustCalledComputationRoot(HloInstruction* hlo) {
bool BFloat16Propagation::ResolveInconsistencyOfAliasingBuffersHelper(
HloComputation* computation,
- tensorflow::gtl::FlatSet<const HloComputation*>* visited_computations) {
+ absl::flat_hash_set<const HloComputation*>* visited_computations) {
bool parameter_changed = false;
auto insts = computation->MakeInstructionPostOrder();
// Do the adjustment on each instruction in the computation in reverse
@@ -560,7 +561,7 @@ bool BFloat16Propagation::ResolveInconsistencyOfAliasingBuffersHelper(
// another input parameter. A fixed point will be reached because the
// parameters can only be changed from BF16 to F32, not the other way
// around.
- tensorflow::gtl::FlatSet<const HloComputation*> visited_in_while;
+ absl::flat_hash_set<const HloComputation*> visited_in_while;
while (ResolveInconsistencyOfAliasingBuffersHelper(hlo->while_condition(),
&visited_in_while) ||
ResolveInconsistencyOfAliasingBuffersHelper(hlo->while_body(),
@@ -587,7 +588,7 @@ void BFloat16Propagation::ResolveInconsistencyOfAliasingBuffers(
HloModule* module) {
const auto& computations_topological_order =
module->MakeComputationPostOrder();
- tensorflow::gtl::FlatSet<const HloComputation*> resolved;
+ absl::flat_hash_set<const HloComputation*> resolved;
for (auto comp_it = computations_topological_order.rbegin();
comp_it != computations_topological_order.rend(); ++comp_it) {
if (ContainsKey(resolved, *comp_it)) {
diff --git a/tensorflow/compiler/xla/service/bfloat16_propagation.h b/tensorflow/compiler/xla/service/bfloat16_propagation.h
index c74326f631..5fcaa15c83 100644
--- a/tensorflow/compiler/xla/service/bfloat16_propagation.h
+++ b/tensorflow/compiler/xla/service/bfloat16_propagation.h
@@ -22,6 +22,7 @@ limitations under the License.
#include <vector>
#include "absl/container/flat_hash_map.h"
+#include "absl/container/flat_hash_set.h"
#include "tensorflow/compiler/xla/service/bfloat16_support.h"
#include "tensorflow/compiler/xla/service/hlo_dataflow_analysis.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
@@ -82,7 +83,7 @@ class BFloat16Propagation : public HloModulePass {
// The set of instructions to consider using bfloat16, computed in the forward
// pass.
- tensorflow::gtl::FlatSet<const HloInstruction*> consider_using_bfloat16_;
+ absl::flat_hash_set<const HloInstruction*> consider_using_bfloat16_;
// ***************************
// Functions called and state produced by the backward pass (from root to
@@ -111,12 +112,12 @@ class BFloat16Propagation : public HloModulePass {
// The set of HloInstructions that have been visited in the
// opportunity-finding pass.
- tensorflow::gtl::FlatSet<const HloInstruction*>
+ absl::flat_hash_set<const HloInstruction*>
instructions_visited_in_backward_pass_;
// The set of HloComputations that have been visited in the
// opportunity-finding pass.
- tensorflow::gtl::FlatSet<const HloComputation*>
+ absl::flat_hash_set<const HloComputation*>
computations_visited_in_backward_pass_;
// ***************************
@@ -132,7 +133,7 @@ class BFloat16Propagation : public HloModulePass {
// point is reached.
bool ResolveInconsistencyOfAliasingBuffersHelper(
HloComputation* computation,
- tensorflow::gtl::FlatSet<const HloComputation*>* visited_computations);
+ absl::flat_hash_set<const HloComputation*>* visited_computations);
// Makes the parameters of called computations match how they are called by
// the given HLO.
@@ -183,7 +184,7 @@ class BFloat16Propagation : public HloModulePass {
PrimitiveType target_type);
// The set of F32 HLO values that must be kept in F32.
- tensorflow::gtl::FlatSet<const HloValue*> values_that_must_be_kept_as_f32_;
+ absl::flat_hash_set<const HloValue*> values_that_must_be_kept_as_f32_;
// Mapping from each HloComputation to the number of callers to it in the
// module. Populated at the beginning of this pass.
diff --git a/tensorflow/compiler/xla/service/buffer_assignment.cc b/tensorflow/compiler/xla/service/buffer_assignment.cc
index 3efa0b1dad..2c2d1626c2 100644
--- a/tensorflow/compiler/xla/service/buffer_assignment.cc
+++ b/tensorflow/compiler/xla/service/buffer_assignment.cc
@@ -23,6 +23,7 @@ limitations under the License.
#include <utility>
#include "absl/container/flat_hash_map.h"
+#include "absl/container/flat_hash_set.h"
#include "absl/memory/memory.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/str_format.h"
@@ -43,9 +44,9 @@ namespace xla {
namespace {
using absl::flat_hash_map;
+using absl::flat_hash_set;
using absl::StrAppend;
using absl::StrAppendFormat;
-using ::tensorflow::gtl::FlatSet;
using ::tensorflow::strings::HumanReadableNumBytes;
template <typename T>
@@ -129,8 +130,8 @@ Status GatherComputationsByAllocationType(
// Sets for quickly checking membership. Computations are returned in vectors
// for stable iteration.
- FlatSet<const HloComputation*> thread_local_set;
- FlatSet<const HloComputation*> global_set;
+ flat_hash_set<const HloComputation*> thread_local_set;
+ flat_hash_set<const HloComputation*> global_set;
while (!worklist.empty()) {
auto worklist_front = worklist.front();
@@ -445,7 +446,7 @@ bool BufferAssignment::SharesSliceAtIndex(
bool BufferAssignment::HaveDisjointSlices(const HloInstruction* hlo_a,
const HloInstruction* hlo_b) const {
using SliceSet =
- FlatSet<BufferAllocation::Slice, BufferAllocation::Slice::Hasher>;
+ flat_hash_set<BufferAllocation::Slice, BufferAllocation::Slice::Hasher>;
// Gets the slices all of instr's subshapes. If any subshape doesn't have an
// assigned slice, returns the empty set.
auto collect_slices = [&](const HloInstruction* instr) -> SliceSet {
@@ -815,9 +816,9 @@ bool BufferAssigner::MaybeAssignBuffer(BufferAllocation* allocation,
Status BufferAssigner::AssignBuffersForComputation(
const HloComputation* computation, bool is_thread_local,
- const FlatSet<const LogicalBuffer*>& colocated_buffers,
- const FlatSet<BufferAllocation::Index>& colocated_allocations,
- flat_hash_map<const HloComputation*, FlatSet<const LogicalBuffer*>>*
+ const flat_hash_set<const LogicalBuffer*>& colocated_buffers,
+ const flat_hash_set<BufferAllocation::Index>& colocated_allocations,
+ flat_hash_map<const HloComputation*, flat_hash_set<const LogicalBuffer*>>*
buffers_to_assign_sequentially,
BufferAssignment* assignment) {
// Buffers are sorted and assigned to BufferAllocations in decreasing order of
@@ -853,8 +854,8 @@ Status BufferAssigner::AssignBuffersForComputation(
// buffers_to_assign_sequentially map, even if we end up with an empty set
// of buffers. This ensures we can correctly determine whether to run
// whole-module heap simulation.
- buffers_to_assign_sequentially->emplace(computation,
- FlatSet<const LogicalBuffer*>());
+ buffers_to_assign_sequentially->emplace(
+ computation, flat_hash_set<const LogicalBuffer*>());
}
// Sort the LogicalBuffers first by size. We assign the larger LogicalBuffers
@@ -1046,11 +1047,11 @@ Status BufferAssigner::AssignBuffersForComputation(
return Status::OK();
}
-flat_hash_map<LogicalBuffer::Color, FlatSet<const LogicalBuffer*>,
+flat_hash_map<LogicalBuffer::Color, flat_hash_set<const LogicalBuffer*>,
LogicalBuffer::Color::Hasher>
BufferAssigner::SplitBuffersByColor(
- const FlatSet<const LogicalBuffer*>& buffers) {
- flat_hash_map<LogicalBuffer::Color, FlatSet<const LogicalBuffer*>,
+ const flat_hash_set<const LogicalBuffer*>& buffers) {
+ flat_hash_map<LogicalBuffer::Color, flat_hash_set<const LogicalBuffer*>,
LogicalBuffer::Color::Hasher>
color_map;
for (auto buffer : buffers) {
@@ -1060,7 +1061,8 @@ BufferAssigner::SplitBuffersByColor(
}
Status BufferAssigner::AssignBuffersWithSequentialOrdering(
- const flat_hash_map<const HloComputation*, FlatSet<const LogicalBuffer*>>&
+ const flat_hash_map<const HloComputation*,
+ flat_hash_set<const LogicalBuffer*>>&
buffers_to_assign_sequentially,
bool run_whole_module_heap_simulation, BufferAssignment* assignment) {
// Run the sequence of instructions through the heap simulator. The heuristic
@@ -1086,10 +1088,11 @@ Status BufferAssigner::AssignBuffersWithSequentialOrdering(
// only live for the duration of their calling instructions.
VLOG(1) << "Running whole-module heap simulation";
HloSchedule schedule(&assignment->module());
- FlatSet<const LogicalBuffer*> all_buffers_to_assign;
+ flat_hash_set<const LogicalBuffer*> all_buffers_to_assign;
for (const auto& pair : buffers_to_assign_sequentially) {
const HloComputation* computation = pair.first;
- const FlatSet<const LogicalBuffer*>& buffers_to_assign = pair.second;
+ const flat_hash_set<const LogicalBuffer*>& buffers_to_assign =
+ pair.second;
const std::vector<const HloInstruction*>* instruction_sequence =
hlo_ordering.SequentialOrder(*computation);
CHECK(instruction_sequence != nullptr) << computation->name();
@@ -1123,7 +1126,8 @@ Status BufferAssigner::AssignBuffersWithSequentialOrdering(
VLOG(1) << "Running per-computation heap simulation";
for (const auto& pair : buffers_to_assign_sequentially) {
const HloComputation* computation = pair.first;
- const FlatSet<const LogicalBuffer*>& buffers_to_assign = pair.second;
+ const flat_hash_set<const LogicalBuffer*>& buffers_to_assign =
+ pair.second;
const std::vector<const HloInstruction*>* instruction_sequence =
hlo_ordering.SequentialOrder(*computation);
CHECK(instruction_sequence != nullptr) << computation->name();
@@ -1198,7 +1202,7 @@ std::vector<const LogicalBuffer*> ComputePeakMemoryLogicalBuffers(
// Next gather the set of logical buffers live at the earliest point of
// maximal live set size.
- tensorflow::gtl::FlatSet<const LogicalBuffer*> live_buffers;
+ absl::flat_hash_set<const LogicalBuffer*> live_buffers;
live_size = 0;
for (const auto& event : heap_trace.events()) {
const LogicalBuffer* buffer = id_to_buffer.at(event.buffer_id());
@@ -1588,8 +1592,8 @@ void BufferAssigner::BuildColocatedBufferSets(
void BufferAssigner::AssignColocatedBufferSets(
const std::vector<ColocatedBufferSet>& colocated_buffer_sets,
BufferAssignment* assignment,
- FlatSet<const LogicalBuffer*>* colocated_buffers,
- FlatSet<BufferAllocation::Index>* colocated_allocations) {
+ flat_hash_set<const LogicalBuffer*>* colocated_buffers,
+ flat_hash_set<BufferAllocation::Index>* colocated_allocations) {
for (const ColocatedBufferSet& colocated_buffer_set : colocated_buffer_sets) {
BufferAllocation* allocation = nullptr;
// Set 'entry_parameter_number' and 'entry_parameter_shape_idx' if entry
@@ -1662,8 +1666,8 @@ StatusOr<std::unique_ptr<BufferAssignment>> BufferAssigner::CreateAssignment(
// Once b/32491382 enables module-level liveness analysis, we may be able
// to assign colocated buffers (or at least reuse their allocation for
// buffers outside of the set) in AssignBuffersForComputation.
- FlatSet<const LogicalBuffer*> colocated_buffers;
- FlatSet<BufferAllocation::Index> colocated_allocations;
+ flat_hash_set<const LogicalBuffer*> colocated_buffers;
+ flat_hash_set<BufferAllocation::Index> colocated_allocations;
std::vector<ColocatedBufferSet> colocated_buffer_sets;
BuildColocatedBufferSets(module, assignment->liveness(),
assignment->buffer_size_, &colocated_buffer_sets);
@@ -1681,7 +1685,7 @@ StatusOr<std::unique_ptr<BufferAssignment>> BufferAssigner::CreateAssignment(
// First assign buffers for global computatations. Temporary buffers for
// sequential computations are collected in 'buffers_to_assign_sequentially'.
- flat_hash_map<const HloComputation*, FlatSet<const LogicalBuffer*>>
+ flat_hash_map<const HloComputation*, flat_hash_set<const LogicalBuffer*>>
buffers_to_assign_sequentially;
for (auto* computation : global_computations) {
TF_RETURN_IF_ERROR(AssignBuffersForComputation(
diff --git a/tensorflow/compiler/xla/service/buffer_assignment.h b/tensorflow/compiler/xla/service/buffer_assignment.h
index 9ba40617a3..899cd36e1f 100644
--- a/tensorflow/compiler/xla/service/buffer_assignment.h
+++ b/tensorflow/compiler/xla/service/buffer_assignment.h
@@ -23,6 +23,7 @@ limitations under the License.
#include <vector>
#include "absl/container/flat_hash_map.h"
+#include "absl/container/flat_hash_set.h"
#include "absl/types/span.h"
#include "tensorflow/compiler/xla/service/buffer_liveness.h"
#include "tensorflow/compiler/xla/service/heap_simulator.h"
@@ -34,7 +35,6 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/tuple_points_to_analysis.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/types.h"
-#include "tensorflow/core/lib/gtl/flatset.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/types.h"
@@ -554,11 +554,10 @@ class BufferAssigner {
// true.
Status AssignBuffersForComputation(
const HloComputation* computation, bool is_thread_local,
- const tensorflow::gtl::FlatSet<const LogicalBuffer*>& colocated_buffers,
- const tensorflow::gtl::FlatSet<BufferAllocation::Index>&
- colocated_allocations,
+ const absl::flat_hash_set<const LogicalBuffer*>& colocated_buffers,
+ const absl::flat_hash_set<BufferAllocation::Index>& colocated_allocations,
absl::flat_hash_map<const HloComputation*,
- tensorflow::gtl::FlatSet<const LogicalBuffer*>>*
+ absl::flat_hash_set<const LogicalBuffer*>>*
buffers_to_assign_sequentially,
BufferAssignment* assignment);
@@ -569,7 +568,7 @@ class BufferAssigner {
// assuming all global computations are sequentially ordered.
Status AssignBuffersWithSequentialOrdering(
const absl::flat_hash_map<const HloComputation*,
- tensorflow::gtl::FlatSet<const LogicalBuffer*>>&
+ absl::flat_hash_set<const LogicalBuffer*>>&
buffers_to_assign_sequentially,
bool run_whole_module_heap_simulation, BufferAssignment* assignment);
@@ -589,7 +588,7 @@ class BufferAssigner {
// alias. Explicitly handling these colocated buffers is necessary because
// points-to analysis is computation level scope and does not recognize
// aliasing across computations (b/32491382).
- using ColocatedBufferSet = tensorflow::gtl::FlatSet<const LogicalBuffer*>;
+ using ColocatedBufferSet = absl::flat_hash_set<const LogicalBuffer*>;
// Returns a vector of ColocatedBufferSet objects, where each
// ColocatedBufferSet aggregates a set of related LogicalBuffers from 'module'
@@ -604,8 +603,8 @@ class BufferAssigner {
void AssignColocatedBufferSets(
const std::vector<ColocatedBufferSet>& colocated_buffer_sets,
BufferAssignment* assignment,
- tensorflow::gtl::FlatSet<const LogicalBuffer*>* colocated_buffers,
- tensorflow::gtl::FlatSet<BufferAllocation::Index>* colocated_allocations);
+ absl::flat_hash_set<const LogicalBuffer*>* colocated_buffers,
+ absl::flat_hash_set<BufferAllocation::Index>* colocated_allocations);
// Adds the 'colocated_set' of buffers to 'colocated_buffer_sets', maintaining
// the invariant that all sets in 'colocated_buffer_sets' are disjoint.
@@ -624,10 +623,9 @@ class BufferAssigner {
// Split a set of buffers into several sets, each of which contains buffers
// colored with the same color.
absl::flat_hash_map<LogicalBuffer::Color,
- tensorflow::gtl::FlatSet<const LogicalBuffer*>,
+ absl::flat_hash_set<const LogicalBuffer*>,
LogicalBuffer::Color::Hasher>
- SplitBuffersByColor(
- const tensorflow::gtl::FlatSet<const LogicalBuffer*>& buffers);
+ SplitBuffersByColor(const absl::flat_hash_set<const LogicalBuffer*>& buffers);
// If true, buffer assignments assumes that input parameter buffers and output
// buffers can be shared if their sizes match.
diff --git a/tensorflow/compiler/xla/service/buffer_liveness.h b/tensorflow/compiler/xla/service/buffer_liveness.h
index 2911bbcfbf..f939a426ea 100644
--- a/tensorflow/compiler/xla/service/buffer_liveness.h
+++ b/tensorflow/compiler/xla/service/buffer_liveness.h
@@ -20,6 +20,7 @@ limitations under the License.
#include <string>
#include <utility>
+#include "absl/container/flat_hash_set.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/service/hlo_ordering.h"
@@ -27,7 +28,6 @@ limitations under the License.
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/core/lib/core/status.h"
-#include "tensorflow/core/lib/gtl/flatset.h"
namespace xla {
@@ -101,7 +101,7 @@ class BufferLiveness {
// Set of LogicalBuffers which are aliased in the output of other
// instructions. For example, a LogicalBuffer which is inserted into a tuple
// is considered to be aliased and will be in this set.
- tensorflow::gtl::FlatSet<const LogicalBuffer*> aliased_buffers_;
+ absl::flat_hash_set<const LogicalBuffer*> aliased_buffers_;
// LogicalBuffers that may be live out of the entry computation.
PointsToSet::BufferSet maybe_live_out_buffers_;
diff --git a/tensorflow/compiler/xla/service/buffer_value_containers.h b/tensorflow/compiler/xla/service/buffer_value_containers.h
index 305914fca8..cc46af5eee 100644
--- a/tensorflow/compiler/xla/service/buffer_value_containers.h
+++ b/tensorflow/compiler/xla/service/buffer_value_containers.h
@@ -16,10 +16,10 @@ limitations under the License.
#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_BUFFER_VALUE_CONTAINERS_H_
#define TENSORFLOW_COMPILER_XLA_SERVICE_BUFFER_VALUE_CONTAINERS_H_
+#include "absl/container/flat_hash_set.h"
#include "tensorflow/compiler/xla/service/buffer_value.h"
#include "tensorflow/compiler/xla/service/logical_buffer.h"
#include "tensorflow/core/lib/gtl/compactptrset.h"
-#include "tensorflow/core/lib/gtl/flatset.h"
namespace xla {
@@ -38,7 +38,7 @@ BufferValueCompactPointerSet ToBufferValueCompactPointerSet(
return output;
}
-using BufferValueFlatSet = tensorflow::gtl::FlatSet<const BufferValue*>;
+using BufferValueFlatSet = absl::flat_hash_set<const BufferValue*>;
template <class LogicalBufferContainerT>
BufferValueFlatSet ToBufferValueFlatSet(
const LogicalBufferContainerT& logical_buffer_container) {
diff --git a/tensorflow/compiler/xla/service/call_graph.cc b/tensorflow/compiler/xla/service/call_graph.cc
index 23b2a32709..bdd5069632 100644
--- a/tensorflow/compiler/xla/service/call_graph.cc
+++ b/tensorflow/compiler/xla/service/call_graph.cc
@@ -17,6 +17,7 @@ limitations under the License.
#include <queue>
+#include "absl/container/flat_hash_set.h"
#include "absl/memory/memory.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/str_format.h"
@@ -138,7 +139,7 @@ CallGraphNode& CallGraph::GetNode(const HloComputation* computation) {
bool CallGraph::DominatesHelper(
const HloComputation* a, const HloComputation* b,
- tensorflow::gtl::FlatSet<const HloComputation*>* visited) const {
+ absl::flat_hash_set<const HloComputation*>* visited) const {
if (a == b || ContainsKey(*visited, b)) {
// The call graph is guaranteed to be acyclic so any previously visited node
// we encounter was already determined to be dominated.
@@ -163,7 +164,7 @@ bool CallGraph::DominatesHelper(
bool CallGraph::Dominates(const HloComputation* a,
const HloComputation* b) const {
- tensorflow::gtl::FlatSet<const HloComputation*> visited;
+ absl::flat_hash_set<const HloComputation*> visited;
return DominatesHelper(a, b, &visited);
}
@@ -277,7 +278,7 @@ std::unique_ptr<CallGraph> CallGraph::Build(const HloModule* module) {
Status CallGraph::VisitNodesInternal(
const VisitorFunction& visitor_func, const CallGraphNode& node,
- tensorflow::gtl::FlatSet<const CallGraphNode*>* visited) const {
+ absl::flat_hash_set<const CallGraphNode*>* visited) const {
auto pair = visited->insert(&node);
if (!pair.second) {
// Node was not inserted. Node has already been visited.
@@ -294,7 +295,7 @@ Status CallGraph::VisitNodesInternal(
Status CallGraph::VisitNodes(const VisitorFunction& visitor_func,
bool visit_unreachable_nodes) const {
- tensorflow::gtl::FlatSet<const CallGraphNode*> visited;
+ absl::flat_hash_set<const CallGraphNode*> visited;
if (visit_unreachable_nodes) {
// Traverse from all roots in the call graph.
for (const CallGraphNode& node : nodes()) {
diff --git a/tensorflow/compiler/xla/service/call_graph.h b/tensorflow/compiler/xla/service/call_graph.h
index 0c2e9b99db..cb56f4789d 100644
--- a/tensorflow/compiler/xla/service/call_graph.h
+++ b/tensorflow/compiler/xla/service/call_graph.h
@@ -21,10 +21,10 @@ limitations under the License.
#include <ostream>
#include "absl/container/flat_hash_map.h"
+#include "absl/container/flat_hash_set.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
-#include "tensorflow/core/lib/gtl/flatset.h"
namespace xla {
@@ -145,12 +145,12 @@ class CallGraphNode {
// The computations called by this computation. The vector is used for a
// stable ordering and the set enables fast membership testing.
std::vector<HloComputation*> callees_;
- tensorflow::gtl::FlatSet<HloComputation*> callee_set_;
+ absl::flat_hash_set<HloComputation*> callee_set_;
// The computations which call this computation. The vector is used for a
// stable ordering and the set enables fast membership testing.
std::vector<HloComputation*> callers_;
- tensorflow::gtl::FlatSet<HloComputation*> caller_set_;
+ absl::flat_hash_set<HloComputation*> caller_set_;
// The call sites in this computation
std::vector<CallSite> callsites_;
@@ -250,14 +250,14 @@ class CallGraph {
// 'visited'.
Status VisitNodesInternal(
const VisitorFunction& visitor_func, const CallGraphNode& node,
- tensorflow::gtl::FlatSet<const CallGraphNode*>* visited) const;
+ absl::flat_hash_set<const CallGraphNode*>* visited) const;
// Recursive helper for computing whether 'a' dominates 'b' in the call
// graph. 'b_ancestor' is the currently visited node (which starts at 'b'),
// and 'visited' is the set of computations which have been visited.
bool DominatesHelper(
const HloComputation* a, const HloComputation* b,
- tensorflow::gtl::FlatSet<const HloComputation*>* visited) const;
+ absl::flat_hash_set<const HloComputation*>* visited) const;
// The HLO module represented by this call graph.
const HloModule* module_ = nullptr;
diff --git a/tensorflow/compiler/xla/service/copy_insertion.cc b/tensorflow/compiler/xla/service/copy_insertion.cc
index 7f78412924..f35324aa35 100644
--- a/tensorflow/compiler/xla/service/copy_insertion.cc
+++ b/tensorflow/compiler/xla/service/copy_insertion.cc
@@ -16,6 +16,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/copy_insertion.h"
#include "absl/container/flat_hash_map.h"
+#include "absl/container/flat_hash_set.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/str_join.h"
#include "tensorflow/compiler/xla/service/hlo_alias_analysis.h"
@@ -32,7 +33,6 @@ limitations under the License.
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/util.h"
-#include "tensorflow/core/lib/gtl/flatset.h"
#include "tensorflow/core/platform/logging.h"
namespace xla {
@@ -904,7 +904,7 @@ class CopyRemover {
// The heads of all the value lists. Each value list represents the HLO
// values contained in a particular HLO buffer. The values in the list are
// in dependency order.
- tensorflow::gtl::FlatSet<const ValueNode*> value_lists_;
+ absl::flat_hash_set<const ValueNode*> value_lists_;
// Copy removal requires fast access to the value list elements
// corresponding to the source and destination values of the kCopy
@@ -1009,7 +1009,7 @@ Status CopyInsertion::AddSpecialCaseCopies(const CallGraph& call_graph,
HloInstruction* root = computation->root_instruction();
// Mark nondistinct/ambiguous indices.
- tensorflow::gtl::FlatSet<const HloBuffer*> seen;
+ absl::flat_hash_set<const HloBuffer*> seen;
ShapeUtil::ForEachSubshape(
root->shape(), [&](const Shape& /*subshape*/, const ShapeIndex& index) {
std::vector<const HloBuffer*> buffers_at_index =
diff --git a/tensorflow/compiler/xla/service/cpu/BUILD b/tensorflow/compiler/xla/service/cpu/BUILD
index 6a83909a3b..ae4c6e962d 100644
--- a/tensorflow/compiler/xla/service/cpu/BUILD
+++ b/tensorflow/compiler/xla/service/cpu/BUILD
@@ -291,6 +291,7 @@ cc_library(
"//tensorflow/compiler/xla/service/llvm_ir:tuple_ops",
"//tensorflow/core:lib",
"@com_google_absl//absl/container:flat_hash_map",
+ "@com_google_absl//absl/container:flat_hash_set",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/strings:str_format",
"@com_google_absl//absl/types:span",
diff --git a/tensorflow/compiler/xla/service/cpu/ir_emitter.cc b/tensorflow/compiler/xla/service/cpu/ir_emitter.cc
index 953a75c35f..a70abb117a 100644
--- a/tensorflow/compiler/xla/service/cpu/ir_emitter.cc
+++ b/tensorflow/compiler/xla/service/cpu/ir_emitter.cc
@@ -25,6 +25,7 @@ limitations under the License.
#include <vector>
#include "absl/container/flat_hash_map.h"
+#include "absl/container/flat_hash_set.h"
#include "tensorflow/core/lib/math/math_util.h"
#include "tensorflow/core/platform/logging.h"
// IWYU pragma: no_include "llvm/IR/Intrinsics.gen.inc"
@@ -68,7 +69,6 @@ limitations under the License.
#include "tensorflow/compiler/xla/window_util.h"
#include "tensorflow/core/lib/core/bits.h"
#include "tensorflow/core/lib/core/errors.h"
-#include "tensorflow/core/lib/gtl/flatset.h"
namespace xla {
@@ -1400,8 +1400,8 @@ static bool ReductionPreservesLayout(const HloInstruction& reduce) {
// [0->0, 3->1].
absl::flat_hash_map<int64, int64> unreduced_dim_map;
- gtl::FlatSet<int64> reduced_dims(reduce.dimensions().begin(),
- reduce.dimensions().end());
+ absl::flat_hash_set<int64> reduced_dims(reduce.dimensions().begin(),
+ reduce.dimensions().end());
const Shape& operand_shape = reduce.operand(0)->shape();
const Shape& result_shape = reduce.shape();
@@ -1977,7 +1977,7 @@ Status IrEmitter::HandleSlice(HloInstruction* slice) {
//
// * Implement the memcpy within the innermost loop.
- gtl::FlatSet<int64> inner_dims;
+ absl::flat_hash_set<int64> inner_dims;
for (int64 dim : LayoutUtil::MinorToMajor(layout)) {
if (operand->shape().dimensions(dim) != slice->shape().dimensions(dim)) {
break;
diff --git a/tensorflow/compiler/xla/service/cpu/tests/cpu_noalias_test.cc b/tensorflow/compiler/xla/service/cpu/tests/cpu_noalias_test.cc
index 7af51db55a..b35fd9dad8 100644
--- a/tensorflow/compiler/xla/service/cpu/tests/cpu_noalias_test.cc
+++ b/tensorflow/compiler/xla/service/cpu/tests/cpu_noalias_test.cc
@@ -121,7 +121,7 @@ TEST_F(CpuNoAliasTest, Concat) {
CHECK: %read_concat2_array = load {{.*}} !alias.scope [[concat1_noalias]], !noalias [[concat1_scope]]
CHECK-DAG: [[buf_size32:![0-9]+]] = !{!"buffer:{{.*}} size:32
CHECK-DAG: [[buf_size48:![0-9]+]] = !{!"buffer:{{.*}} size:48
- CHECK-DAG: [[param_x_noalias]] = !{[[buf_size32]], [[buf_size48]]}
+ CHECK-DAG: [[param_x_noalias]] = !{[[buf_size48]], [[buf_size32]]}
CHECK-DAG: [[concat1_scope]] = !{[[buf_size32]]}
CHECK-DAG: [[concat1_noalias]] = !{[[buf_size48]]}
)";
diff --git a/tensorflow/compiler/xla/service/gpu/BUILD b/tensorflow/compiler/xla/service/gpu/BUILD
index e65d3fa332..a838464cae 100644
--- a/tensorflow/compiler/xla/service/gpu/BUILD
+++ b/tensorflow/compiler/xla/service/gpu/BUILD
@@ -476,6 +476,7 @@ cc_library(
"//tensorflow/compiler/xla/service:hlo",
"//tensorflow/compiler/xla/service:instruction_fusion",
"//tensorflow/compiler/xla/service:pattern_matcher",
+ "@com_google_absl//absl/container:flat_hash_set",
],
)
@@ -508,6 +509,7 @@ cc_library(
"//tensorflow/compiler/xla/service:multi_output_fusion",
"//tensorflow/core:lib",
"@com_google_absl//absl/algorithm:container",
+ "@com_google_absl//absl/container:flat_hash_set",
],
)
@@ -541,6 +543,7 @@ cc_library(
"//tensorflow/compiler/xla/service:hlo_dataflow_analysis",
"//tensorflow/compiler/xla/service:hlo_pass",
"//tensorflow/core:lib",
+ "@com_google_absl//absl/container:flat_hash_set",
],
)
diff --git a/tensorflow/compiler/xla/service/gpu/gpu_copy_insertion.cc b/tensorflow/compiler/xla/service/gpu/gpu_copy_insertion.cc
index 79c74e7e8b..e2ab00ce41 100644
--- a/tensorflow/compiler/xla/service/gpu/gpu_copy_insertion.cc
+++ b/tensorflow/compiler/xla/service/gpu/gpu_copy_insertion.cc
@@ -19,6 +19,7 @@ limitations under the License.
#include <set>
#include <vector>
+#include "absl/container/flat_hash_set.h"
#include "tensorflow/compiler/xla/service/call_graph.h"
#include "tensorflow/compiler/xla/service/copy_insertion.h"
#include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h"
@@ -27,7 +28,6 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
#include "tensorflow/core/lib/core/status.h"
-#include "tensorflow/core/lib/gtl/flatset.h"
#include "tensorflow/core/platform/logging.h"
namespace xla {
diff --git a/tensorflow/compiler/xla/service/gpu/instruction_fusion.cc b/tensorflow/compiler/xla/service/gpu/instruction_fusion.cc
index 4d5d8e99f8..b61f038739 100644
--- a/tensorflow/compiler/xla/service/gpu/instruction_fusion.cc
+++ b/tensorflow/compiler/xla/service/gpu/instruction_fusion.cc
@@ -15,6 +15,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/gpu/instruction_fusion.h"
+#include "absl/container/flat_hash_set.h"
#include "tensorflow/compiler/xla/service/gpu/gpu_fusible.h"
#include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
@@ -125,8 +126,8 @@ bool IsIEEEFloatingPointScalarConstant(const HloInstruction* constant) {
}
// Compute the precise number of operands to the new fusion.
- tensorflow::gtl::FlatSet<const HloInstruction*> operands(
- a->operands().begin(), a->operands().end());
+ absl::flat_hash_set<const HloInstruction*> operands(a->operands().begin(),
+ a->operands().end());
operands.insert(b->operands().begin(), b->operands().end());
// If there's an edge between `a` and `b`, don't count it: We're fusing that
// producer -> consumer relationship.
diff --git a/tensorflow/compiler/xla/service/gpu/multi_output_fusion.cc b/tensorflow/compiler/xla/service/gpu/multi_output_fusion.cc
index c21f76f6eb..835924024b 100644
--- a/tensorflow/compiler/xla/service/gpu/multi_output_fusion.cc
+++ b/tensorflow/compiler/xla/service/gpu/multi_output_fusion.cc
@@ -24,6 +24,7 @@ limitations under the License.
#include <utility>
#include "absl/algorithm/container.h"
+#include "absl/container/flat_hash_set.h"
#include "tensorflow/compiler/xla/layout_util.h"
#include "tensorflow/compiler/xla/service/gpu/gpu_fusible.h"
#include "tensorflow/compiler/xla/service/gpu/instruction_fusion.h"
@@ -31,7 +32,6 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
#include "tensorflow/compiler/xla/shape_util.h"
-#include "tensorflow/core/lib/gtl/flatset.h"
#include "tensorflow/core/platform/types.h"
namespace xla {
@@ -101,7 +101,7 @@ bool GpuMultiOutputFusion::IsFusible(HloInstruction* instr) {
int64 GpuMultiOutputFusion::GetProfit(HloInstruction* instr1,
HloInstruction* instr2) {
- tensorflow::gtl::FlatSet<HloInstruction*> in_list;
+ absl::flat_hash_set<HloInstruction*> in_list;
for (auto instr : instr1->operands()) {
if (!IsProfitableOperand(instr)) {
continue;
@@ -148,7 +148,7 @@ bool GpuMultiOutputFusion::DoProducerConsumerMultiOutputFusion() {
bool changed = false;
RecomputeReachability();
- tensorflow::gtl::FlatSet<HloInstruction*> to_fuse;
+ absl::flat_hash_set<HloInstruction*> to_fuse;
// Keep a list of the instructions to fuse after making all the fusion
// decisions. We first aggressively add instructions to potential_fusion_list,
// then filter out instructions that will be no longer fusible because of
diff --git a/tensorflow/compiler/xla/service/heap_simulator.cc b/tensorflow/compiler/xla/service/heap_simulator.cc
index 147776c8c4..b343305554 100644
--- a/tensorflow/compiler/xla/service/heap_simulator.cc
+++ b/tensorflow/compiler/xla/service/heap_simulator.cc
@@ -19,6 +19,7 @@ limitations under the License.
#include <vector>
#include "absl/container/flat_hash_map.h"
+#include "absl/container/flat_hash_set.h"
#include "absl/memory/memory.h"
#include "tensorflow/compiler/xla/map_util.h"
#include "tensorflow/compiler/xla/util.h"
@@ -26,7 +27,7 @@ limitations under the License.
namespace xla {
using absl::flat_hash_map;
-using tensorflow::gtl::FlatSet;
+using absl::flat_hash_set;
/*static*/
StatusOr<int64> HeapSimulator::MinimumMemoryForModule(
@@ -116,9 +117,9 @@ Status HeapSimulator::RunComputation(
// 'used_buffers' is the reverse map - it tracks which buffers were used by an
// instruction, so that we can remove the instructions from a buffer's live
// set after they are visited.
- flat_hash_map<const BufferValue*, FlatSet<const HloInstruction*>>
+ flat_hash_map<const BufferValue*, flat_hash_set<const HloInstruction*>>
live_buffers;
- flat_hash_map<const HloInstruction*, FlatSet<const BufferValue*>>
+ flat_hash_map<const HloInstruction*, flat_hash_set<const BufferValue*>>
used_buffers;
auto add_user_to_buffer = [this, &live_buffers, &used_buffers](
const HloInstruction* user,
@@ -216,7 +217,7 @@ Status HeapSimulator::RunComputation(
VLOG(4) << " Removing user " << instruction->name() << " from buffer "
<< operand_buffer->ToString();
auto it = live_buffers.find(operand_buffer);
- FlatSet<const HloInstruction*>* live_set = &it->second;
+ flat_hash_set<const HloInstruction*>* live_set = &it->second;
live_set->erase(instruction);
if (live_set->empty()) {
live_buffers.erase(it);
@@ -238,7 +239,7 @@ Status HeapSimulator::RunComputation(
// that we should assign.
// Make sure each buffer get reused at most once.
- FlatSet<const BufferValue*> reused_buffers;
+ flat_hash_set<const BufferValue*> reused_buffers;
for (const BufferValue* buffer : buffers_defined_by_instruction) {
if (IgnoreBuffer(buffer)) {
continue;
@@ -326,7 +327,7 @@ Status HeapSimulator::RunComputation(
to_free.reserve(live_buffers.size());
for (const auto& buffer_pending : live_buffers) {
const BufferValue* buffer = buffer_pending.first;
- const FlatSet<const HloInstruction*>& pending = buffer_pending.second;
+ const flat_hash_set<const HloInstruction*>& pending = buffer_pending.second;
CHECK_EQ(pending.size(), 1) << *buffer;
CHECK(*pending.begin() == nullptr) << *buffer;
to_free.push_back(buffer);
diff --git a/tensorflow/compiler/xla/service/heap_simulator.h b/tensorflow/compiler/xla/service/heap_simulator.h
index a5bb3f81f7..b0295a6163 100644
--- a/tensorflow/compiler/xla/service/heap_simulator.h
+++ b/tensorflow/compiler/xla/service/heap_simulator.h
@@ -22,6 +22,7 @@ limitations under the License.
#include <vector>
#include "absl/container/flat_hash_map.h"
+#include "absl/container/flat_hash_set.h"
#include "tensorflow/compiler/xla/service/buffer_value.h"
#include "tensorflow/compiler/xla/service/buffer_value_containers.h"
#include "tensorflow/compiler/xla/service/hlo.pb.h"
@@ -31,7 +32,6 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/hlo_schedule.h"
#include "tensorflow/compiler/xla/service/tuple_points_to_analysis.h"
#include "tensorflow/compiler/xla/statusor.h"
-#include "tensorflow/core/lib/gtl/flatset.h"
namespace xla {
@@ -197,8 +197,8 @@ class HeapSimulator {
shared_buffers_;
// Hold some sets for error-checking the sequence of Alloc and Free calls.
- tensorflow::gtl::FlatSet<const BufferValue*> allocated_buffers_;
- tensorflow::gtl::FlatSet<const BufferValue*> freed_buffers_;
+ absl::flat_hash_set<const BufferValue*> allocated_buffers_;
+ absl::flat_hash_set<const BufferValue*> freed_buffers_;
// Debugging information filled in while the heap simulator runs.
HeapSimulatorTrace debug_trace_;
diff --git a/tensorflow/compiler/xla/service/hlo_alias_analysis.cc b/tensorflow/compiler/xla/service/hlo_alias_analysis.cc
index b6e1f52cf5..c3da12e273 100644
--- a/tensorflow/compiler/xla/service/hlo_alias_analysis.cc
+++ b/tensorflow/compiler/xla/service/hlo_alias_analysis.cc
@@ -21,6 +21,7 @@ limitations under the License.
#include <vector>
#include "absl/container/flat_hash_map.h"
+#include "absl/container/flat_hash_set.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/str_join.h"
#include "tensorflow/compiler/xla/map_util.h"
@@ -120,7 +121,7 @@ class BufferValueMap {
}
// Return a set of all the values in the given buffer.
- const tensorflow::gtl::FlatSet<const HloValue*>& GetValuesInBuffer(
+ const absl::flat_hash_set<const HloValue*>& GetValuesInBuffer(
BufferNumber buffer_number) const {
return buffers_.at(buffer_number);
}
@@ -143,7 +144,7 @@ class BufferValueMap {
// Move the given value into the given buffer.
void MoveValueToBuffer(const HloValue& value, BufferNumber buffer_number) {
BufferNumber old_buffer_number = value_to_buffer_number_.at(&value);
- tensorflow::gtl::FlatSet<const HloValue*>& old_value_set =
+ absl::flat_hash_set<const HloValue*>& old_value_set =
buffers_.at(old_buffer_number);
old_value_set.erase(&value);
if (old_value_set.empty()) {
@@ -291,7 +292,7 @@ class BufferValueMap {
const HloDataflowAnalysis& dataflow_;
// A map containing the set of values contained in each buffer.
- absl::flat_hash_map<BufferNumber, tensorflow::gtl::FlatSet<const HloValue*>>
+ absl::flat_hash_map<BufferNumber, absl::flat_hash_set<const HloValue*>>
buffers_;
// A map indicating which buffer each value is contained in.
@@ -351,7 +352,7 @@ bool HloAliasAnalysis::InstructionBuffersAreAmbiguous(
bool HloAliasAnalysis::InstructionBuffersAreDistinct(
const HloInstruction* instruction) const {
- tensorflow::gtl::FlatSet<const HloBuffer*> buffers_seen;
+ absl::flat_hash_set<const HloBuffer*> buffers_seen;
for (const auto& pair :
dataflow_analysis_->GetInstructionValueSet(instruction)) {
const HloValueSet& value_set = pair.second;
diff --git a/tensorflow/compiler/xla/service/hlo_buffer.cc b/tensorflow/compiler/xla/service/hlo_buffer.cc
index 6c11a073b7..9c3aa0e64d 100644
--- a/tensorflow/compiler/xla/service/hlo_buffer.cc
+++ b/tensorflow/compiler/xla/service/hlo_buffer.cc
@@ -20,6 +20,7 @@ limitations under the License.
#include <utility>
#include <vector>
+#include "absl/container/flat_hash_set.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/str_join.h"
#include "tensorflow/compiler/xla/map_util.h"
@@ -28,7 +29,6 @@ limitations under the License.
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/core/lib/core/errors.h"
-#include "tensorflow/core/lib/gtl/flatset.h"
#include "tensorflow/core/platform/logging.h"
namespace xla {
diff --git a/tensorflow/compiler/xla/service/hlo_computation.cc b/tensorflow/compiler/xla/service/hlo_computation.cc
index 257dd5876f..6ef67ab0a8 100644
--- a/tensorflow/compiler/xla/service/hlo_computation.cc
+++ b/tensorflow/compiler/xla/service/hlo_computation.cc
@@ -25,6 +25,7 @@ limitations under the License.
#include "absl/algorithm/container.h"
#include "absl/container/flat_hash_map.h"
+#include "absl/container/flat_hash_set.h"
#include "absl/memory/memory.h"
#include "absl/strings/numbers.h"
#include "absl/strings/str_cat.h"
@@ -40,7 +41,6 @@ limitations under the License.
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/status.h"
-#include "tensorflow/core/lib/gtl/flatset.h"
#include "tensorflow/core/platform/logging.h"
namespace xla {
@@ -278,10 +278,9 @@ void HloComputation::set_root_instruction(HloInstruction* new_root_instruction,
namespace {
// Helper which builds a post order of the HLO call graph.
-void ComputeComputationPostOrder(
- HloComputation* computation,
- tensorflow::gtl::FlatSet<HloComputation*>* visited,
- std::vector<HloComputation*>* post_order) {
+void ComputeComputationPostOrder(HloComputation* computation,
+ absl::flat_hash_set<HloComputation*>* visited,
+ std::vector<HloComputation*>* post_order) {
if (visited->insert(computation).second) {
for (auto* instruction : computation->instructions()) {
for (HloComputation* called_computation :
@@ -416,7 +415,7 @@ std::vector<HloInstruction*> HloComputation::MakeInstructionPostOrder() const {
std::vector<HloComputation*> HloComputation::MakeEmbeddedComputationsList()
const {
- tensorflow::gtl::FlatSet<HloComputation*> visited;
+ absl::flat_hash_set<HloComputation*> visited;
std::vector<HloComputation*> post_order;
// To avoid special handling of this computation, cast away const of
diff --git a/tensorflow/compiler/xla/service/hlo_computation.h b/tensorflow/compiler/xla/service/hlo_computation.h
index af929ac009..d87ab4bda1 100644
--- a/tensorflow/compiler/xla/service/hlo_computation.h
+++ b/tensorflow/compiler/xla/service/hlo_computation.h
@@ -26,6 +26,7 @@ limitations under the License.
#include <vector>
#include "absl/container/flat_hash_map.h"
+#include "absl/container/flat_hash_set.h"
#include "absl/types/span.h"
#include "tensorflow/compiler/xla/iterator_util.h"
#include "tensorflow/compiler/xla/map_util.h"
@@ -41,7 +42,6 @@ limitations under the License.
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/lib/core/status.h"
-#include "tensorflow/core/lib/gtl/flatset.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/types.h"
diff --git a/tensorflow/compiler/xla/service/hlo_cse.cc b/tensorflow/compiler/xla/service/hlo_cse.cc
index b59c9ba3ed..e602107cbe 100644
--- a/tensorflow/compiler/xla/service/hlo_cse.cc
+++ b/tensorflow/compiler/xla/service/hlo_cse.cc
@@ -23,6 +23,7 @@ limitations under the License.
#include <utility>
#include <vector>
+#include "absl/container/flat_hash_set.h"
#include "absl/container/inlined_vector.h"
#include "tensorflow/compiler/xla/layout_util.h"
#include "tensorflow/compiler/xla/literal.h"
@@ -34,7 +35,6 @@ limitations under the License.
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/lib/core/errors.h"
-#include "tensorflow/core/lib/gtl/flatset.h"
#include "tensorflow/core/lib/hash/hash.h"
namespace xla {
@@ -137,8 +137,8 @@ StatusOr<bool> HloCSE::Run(HloModule* module) {
// HLO instructions are grouped into equivalency classes by using the
// cse_equal predicate defined above. This set holds a representative
// instruction for each class.
- tensorflow::gtl::FlatSet<HloInstruction*, decltype(&CseHash),
- decltype(cse_equal)>
+ absl::flat_hash_set<HloInstruction*, decltype(&CseHash),
+ decltype(cse_equal)>
representatives(/*N=*/computation->instruction_count() + 1, &CseHash,
cse_equal);
for (auto instruction : computation->MakeInstructionPostOrder()) {
diff --git a/tensorflow/compiler/xla/service/hlo_dataflow_analysis.cc b/tensorflow/compiler/xla/service/hlo_dataflow_analysis.cc
index 6a63681996..44cde4a3d2 100644
--- a/tensorflow/compiler/xla/service/hlo_dataflow_analysis.cc
+++ b/tensorflow/compiler/xla/service/hlo_dataflow_analysis.cc
@@ -19,6 +19,7 @@ limitations under the License.
#include <queue>
#include <vector>
+#include "absl/container/flat_hash_set.h"
#include "absl/container/inlined_vector.h"
#include "absl/memory/memory.h"
#include "absl/strings/str_cat.h"
@@ -91,7 +92,7 @@ HloDataflowAnalysis::HloDataflowAnalysis(
bool HloDataflowAnalysis::AreTransitiveUsesElementwiseOrTuple(
const HloInstruction* inst) {
- tensorflow::gtl::FlatSet<const HloInstruction*> visited;
+ absl::flat_hash_set<const HloInstruction*> visited;
absl::InlinedVector<const HloInstruction*, 4> stack;
stack.push_back(inst);
while (!stack.empty()) {
@@ -159,8 +160,8 @@ void HloDataflowAnalysis::MarkValueForDeletion(HloValue::Id value_id) {
void HloDataflowAnalysis::DeleteMarkedValues() {
#ifndef NDEBUG
// Verify that no marked-for-deletion values are in any of the value sets.
- tensorflow::gtl::FlatSet<HloValue::Id> id_set(value_ids_to_delete_.begin(),
- value_ids_to_delete_.end());
+ absl::flat_hash_set<HloValue::Id> id_set(value_ids_to_delete_.begin(),
+ value_ids_to_delete_.end());
for (const auto& pair : value_sets_) {
const HloInstruction* instruction = pair.first;
const InstructionValueSet& instruction_value_set = pair.second;
@@ -673,7 +674,7 @@ bool HloDataflowAnalysis::UpdateInstructionValueSet(
void HloDataflowAnalysis::Propagate() {
std::queue<HloInstruction*> worklist;
- tensorflow::gtl::FlatSet<HloInstruction*> workset;
+ absl::flat_hash_set<HloInstruction*> workset;
auto add_to_worklist = [&worklist, &workset](HloInstruction* instruction) {
if (workset.insert(instruction).second) {
worklist.push(instruction);
diff --git a/tensorflow/compiler/xla/service/hlo_domain_map.cc b/tensorflow/compiler/xla/service/hlo_domain_map.cc
index 159c39d557..6ca1255ede 100644
--- a/tensorflow/compiler/xla/service/hlo_domain_map.cc
+++ b/tensorflow/compiler/xla/service/hlo_domain_map.cc
@@ -18,6 +18,7 @@ limitations under the License.
#include <algorithm>
#include "absl/container/flat_hash_map.h"
+#include "absl/container/flat_hash_set.h"
#include "absl/memory/memory.h"
#include "tensorflow/compiler/xla/map_util.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
@@ -217,7 +218,7 @@ bool HloDomainMap::IsDomainInstruction(HloInstruction* instruction) const {
/* static */ std::vector<HloInstruction*>
HloDomainMap::MakeNonDomainInstructions(
- const tensorflow::gtl::FlatSet<HloInstruction*>& instruction_set,
+ const absl::flat_hash_set<HloInstruction*>& instruction_set,
const InstructionOrderMap& instructions_order) {
std::vector<HloInstruction*> instructions;
instructions.reserve(instruction_set.size());
diff --git a/tensorflow/compiler/xla/service/hlo_domain_map.h b/tensorflow/compiler/xla/service/hlo_domain_map.h
index 8584bc021d..c8d581b746 100644
--- a/tensorflow/compiler/xla/service/hlo_domain_map.h
+++ b/tensorflow/compiler/xla/service/hlo_domain_map.h
@@ -20,13 +20,13 @@ limitations under the License.
#include <vector>
#include "absl/container/flat_hash_map.h"
+#include "absl/container/flat_hash_set.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_domain_metadata.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/core/lib/core/status.h"
-#include "tensorflow/core/lib/gtl/flatset.h"
namespace xla {
@@ -110,7 +110,7 @@ class HloDomainMap {
// Out of an instruction set, returns a vector of all the ones which are not
// a kDomain kind.
static std::vector<HloInstruction*> MakeNonDomainInstructions(
- const tensorflow::gtl::FlatSet<HloInstruction*>& instruction_set,
+ const absl::flat_hash_set<HloInstruction*>& instruction_set,
const InstructionOrderMap& instructions_order);
// Populates domain_metadata_id_ that maps each HloInstruction to the unique
diff --git a/tensorflow/compiler/xla/service/hlo_domain_metadata.h b/tensorflow/compiler/xla/service/hlo_domain_metadata.h
index 302807f816..d3c83c15ae 100644
--- a/tensorflow/compiler/xla/service/hlo_domain_metadata.h
+++ b/tensorflow/compiler/xla/service/hlo_domain_metadata.h
@@ -20,11 +20,11 @@ limitations under the License.
#include <string>
#include <vector>
+#include "absl/container/flat_hash_set.h"
#include "absl/strings/string_view.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/core/lib/core/status.h"
-#include "tensorflow/core/lib/gtl/flatset.h"
namespace xla {
@@ -42,7 +42,7 @@ class DomainMetadata {
// operand/user pathways, without crossing a kDomain instruction of a given
// kind. The reach_set can contain kDomain instructions of other kinds, if
// two domains of different kind intersect each other.
- tensorflow::gtl::FlatSet<HloInstruction*> reach_set;
+ absl::flat_hash_set<HloInstruction*> reach_set;
// The same instructions in reach_set, but purged from kDomain instructions
// and ordered according to their computation graph post-order, i.e.
@@ -55,8 +55,8 @@ class DomainMetadata {
// whose dataflow enters the reach set (domain), while the exit_domains
// contains the set of kDomain instructions whose dataflow exit the reach
// set.
- tensorflow::gtl::FlatSet<HloInstruction*> enter_domains;
- tensorflow::gtl::FlatSet<HloInstruction*> exit_domains;
+ absl::flat_hash_set<HloInstruction*> enter_domains;
+ absl::flat_hash_set<HloInstruction*> exit_domains;
};
virtual ~DomainMetadata() = default;
diff --git a/tensorflow/compiler/xla/service/hlo_instruction.cc b/tensorflow/compiler/xla/service/hlo_instruction.cc
index 5d5c9c7e58..0207f9ae3f 100644
--- a/tensorflow/compiler/xla/service/hlo_instruction.cc
+++ b/tensorflow/compiler/xla/service/hlo_instruction.cc
@@ -23,6 +23,7 @@ limitations under the License.
#include "absl/algorithm/container.h"
#include "absl/container/flat_hash_map.h"
+#include "absl/container/flat_hash_set.h"
#include "absl/container/inlined_vector.h"
#include "absl/memory/memory.h"
#include "absl/strings/ascii.h"
@@ -44,7 +45,6 @@ limitations under the License.
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/core/lib/core/errors.h"
-#include "tensorflow/core/lib/gtl/flatset.h"
#include "tensorflow/core/lib/gtl/map_util.h"
#include "tensorflow/core/platform/human_readable_json.h"
#include "tensorflow/core/platform/logging.h"
@@ -1433,7 +1433,7 @@ int64 HloInstruction::operand_index(const HloInstruction* target) const {
HloInstruction::InstructionVector HloInstruction::unique_operands() const {
InstructionVector unique;
- tensorflow::gtl::FlatSet<const HloInstruction*> seen;
+ absl::flat_hash_set<const HloInstruction*> seen;
for (HloInstruction* operand : operands()) {
if (seen.insert(operand).second) {
unique.push_back(operand);
diff --git a/tensorflow/compiler/xla/service/hlo_memory_scheduler.cc b/tensorflow/compiler/xla/service/hlo_memory_scheduler.cc
index 1c2b2868fd..55314d0ae9 100644
--- a/tensorflow/compiler/xla/service/hlo_memory_scheduler.cc
+++ b/tensorflow/compiler/xla/service/hlo_memory_scheduler.cc
@@ -21,6 +21,7 @@ limitations under the License.
#include <vector>
#include "absl/container/flat_hash_map.h"
+#include "absl/container/flat_hash_set.h"
#include "tensorflow/compiler/xla/service/heap_simulator.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/tuple_points_to_analysis.h"
@@ -111,7 +112,7 @@ class ListScheduler {
// LogicalBuffer is in an operand of the instruction as indicated by
// points-to analysis.
for (auto* instruction : computation.instructions()) {
- tensorflow::gtl::FlatSet<const LogicalBuffer*> instr_uses;
+ absl::flat_hash_set<const LogicalBuffer*> instr_uses;
for (auto* operand : instruction->operands()) {
points_to_analysis.GetPointsToSet(operand).ForEachElement(
[&](const ShapeIndex& /*index*/,
@@ -360,7 +361,7 @@ class ListScheduler {
std::unordered_map<const LogicalBuffer*, int64> unscheduled_use_count_;
// Set of instructions which have been scheduled.
- tensorflow::gtl::FlatSet<const HloInstruction*> scheduled_instructions_;
+ absl::flat_hash_set<const HloInstruction*> scheduled_instructions_;
};
int64 SumLogicalBufferSizes(
@@ -418,7 +419,7 @@ StatusOr<HloInstructionSequence> DFSMemoryScheduler(
points_to_analysis.GetBuffersDefinedByInstruction(hlo), size_function);
total_sizes[hlo] = logical_buffer_size;
cumulative_total_size += logical_buffer_size;
- tensorflow::gtl::FlatSet<const HloInstruction*> unique_operands(
+ absl::flat_hash_set<const HloInstruction*> unique_operands(
hlo->operands().begin(), hlo->operands().end());
for (const HloInstruction* operand : unique_operands) {
extra_users[hlo] += extra_users[operand];
diff --git a/tensorflow/compiler/xla/service/hlo_module.cc b/tensorflow/compiler/xla/service/hlo_module.cc
index 9359e9a8be..7527e35c95 100644
--- a/tensorflow/compiler/xla/service/hlo_module.cc
+++ b/tensorflow/compiler/xla/service/hlo_module.cc
@@ -24,6 +24,7 @@ limitations under the License.
#include "absl/algorithm/container.h"
#include "absl/container/flat_hash_map.h"
+#include "absl/container/flat_hash_set.h"
#include "absl/memory/memory.h"
#include "absl/strings/str_cat.h"
#include "tensorflow/compiler/xla/map_util.h"
@@ -328,10 +329,10 @@ StatusOr<std::unique_ptr<HloModule>> HloModule::CreateFromProto(
// Because we didn't uniquify the names or the ids, double-check that the
// instruction and computation names and ids are unique from the proto.
- tensorflow::gtl::FlatSet<string> computation_names;
- tensorflow::gtl::FlatSet<string> instruction_names;
- tensorflow::gtl::FlatSet<int> computation_ids;
- tensorflow::gtl::FlatSet<int> instruction_ids;
+ absl::flat_hash_set<string> computation_names;
+ absl::flat_hash_set<string> instruction_names;
+ absl::flat_hash_set<int> computation_ids;
+ absl::flat_hash_set<int> instruction_ids;
for (HloComputation* computation : module->computations()) {
TF_RET_CHECK(!ContainsKey(computation_names, computation->name()))
<< "Computation name is not unique: " << computation->name();
diff --git a/tensorflow/compiler/xla/service/hlo_module_group_util.cc b/tensorflow/compiler/xla/service/hlo_module_group_util.cc
index d83ee71490..fddeb5f0a2 100644
--- a/tensorflow/compiler/xla/service/hlo_module_group_util.cc
+++ b/tensorflow/compiler/xla/service/hlo_module_group_util.cc
@@ -22,6 +22,7 @@ limitations under the License.
#include <string>
#include <utility>
+#include "absl/container/flat_hash_set.h"
#include "absl/memory/memory.h"
#include "absl/strings/str_cat.h"
#include "tensorflow/compiler/xla/service/hlo_casting_utils.h"
@@ -32,7 +33,6 @@ limitations under the License.
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/core/lib/core/errors.h"
-#include "tensorflow/core/lib/gtl/flatset.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/types.h"
@@ -42,7 +42,7 @@ std::vector<HloInstruction*> HloModuleGroupUtil::GlobalPredecessors(
HloInstruction* instruction) {
std::vector<HloInstruction*>
predecessors; // Use a vector to avoid non-determinism.
- tensorflow::gtl::FlatSet<HloInstruction*> unique;
+ absl::flat_hash_set<HloInstruction*> unique;
// Adds to the unique predecessors list; if the predecessors is a companion
// instruction, also add companion instructions; if the predecessors is a
@@ -119,7 +119,7 @@ std::vector<HloInstruction*> HloModuleGroupUtil::GlobalSuccessors(
HloInstruction* instruction) {
std::vector<HloInstruction*>
successors; // Use a vector to avoid non-determinism.
- tensorflow::gtl::FlatSet<HloInstruction*> unique;
+ absl::flat_hash_set<HloInstruction*> unique;
// Adds to the unique successors list; if the successor is a companion
// instruction, also add companion instructions; if the successor is a
diff --git a/tensorflow/compiler/xla/service/hlo_pass_pipeline.cc b/tensorflow/compiler/xla/service/hlo_pass_pipeline.cc
index 59fd01cb58..5e004ce78a 100644
--- a/tensorflow/compiler/xla/service/hlo_pass_pipeline.cc
+++ b/tensorflow/compiler/xla/service/hlo_pass_pipeline.cc
@@ -18,6 +18,7 @@ limitations under the License.
#include <functional>
#include "absl/container/flat_hash_map.h"
+#include "absl/container/flat_hash_set.h"
#include "absl/strings/str_format.h"
#include "absl/strings/str_join.h"
#include "tensorflow/compiler/xla/service/hlo_graph_dumper.h"
@@ -25,7 +26,6 @@ limitations under the License.
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/util.h"
-#include "tensorflow/core/lib/gtl/flatset.h"
#include "tensorflow/core/platform/logging.h"
namespace xla {
@@ -75,8 +75,8 @@ StatusOr<bool> HloPassPipeline::RunPassesInternal(
std::vector<HloPassInterface*> HloPassPipeline::GetEnabledPasses(
const DebugOptions& debug_options) {
auto repeated_field = debug_options.xla_disable_hlo_passes();
- tensorflow::gtl::FlatSet<string> disabled_pass_names(repeated_field.begin(),
- repeated_field.end());
+ absl::flat_hash_set<string> disabled_pass_names(repeated_field.begin(),
+ repeated_field.end());
if (!disabled_pass_names.empty()) {
VLOG(1) << "Passes disabled by --xla_disable_hlo_passes: "
<< absl::StrJoin(disabled_pass_names, ", ");
diff --git a/tensorflow/compiler/xla/service/hlo_rematerialization.cc b/tensorflow/compiler/xla/service/hlo_rematerialization.cc
index abdd9a9212..5ac43808ee 100644
--- a/tensorflow/compiler/xla/service/hlo_rematerialization.cc
+++ b/tensorflow/compiler/xla/service/hlo_rematerialization.cc
@@ -21,6 +21,7 @@ limitations under the License.
#include <string>
#include "absl/container/flat_hash_map.h"
+#include "absl/container/flat_hash_set.h"
#include "absl/container/inlined_vector.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/str_format.h"
@@ -981,7 +982,7 @@ StatusOr<bool> HloRematerialization::RematerializeComputation(
// rematerialization is essentially a move). If the next rematerialization of
// the instruction is also a move then the rematerialization is added to the
// blacklist.
- tensorflow::gtl::FlatSet<const HloInstruction*> remat_move_instructions;
+ absl::flat_hash_set<const HloInstruction*> remat_move_instructions;
// The map from instructions to their rematerializable status.
absl::flat_hash_map<const HloInstruction*, bool> remat_able;
diff --git a/tensorflow/compiler/xla/service/hlo_rematerialization.h b/tensorflow/compiler/xla/service/hlo_rematerialization.h
index 5a02e3a8bb..70d83c04f0 100644
--- a/tensorflow/compiler/xla/service/hlo_rematerialization.h
+++ b/tensorflow/compiler/xla/service/hlo_rematerialization.h
@@ -16,6 +16,7 @@
#define TENSORFLOW_COMPILER_XLA_SERVICE_HLO_REMATERIALIZATION_H_
#include "absl/container/flat_hash_map.h"
+#include "absl/container/flat_hash_set.h"
#include "tensorflow/compiler/xla/service/buffer_liveness.h"
#include "tensorflow/compiler/xla/service/call_graph.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
@@ -122,7 +123,7 @@ class HloRematerialization : public HloModulePass {
// Set of computations which have had rematerialization
// applied. Rematerialization is only applied once per computation.
- tensorflow::gtl::FlatSet<const HloComputation*> rematerialized_computations_;
+ absl::flat_hash_set<const HloComputation*> rematerialized_computations_;
// Count of the total instructions rematerialized.
int64 instructions_rematerialized_ = 0;
diff --git a/tensorflow/compiler/xla/service/hlo_schedule.cc b/tensorflow/compiler/xla/service/hlo_schedule.cc
index 7c5c98f04e..9972eb2077 100644
--- a/tensorflow/compiler/xla/service/hlo_schedule.cc
+++ b/tensorflow/compiler/xla/service/hlo_schedule.cc
@@ -19,6 +19,7 @@ limitations under the License.
#include <vector>
#include "absl/container/flat_hash_map.h"
+#include "absl/container/flat_hash_set.h"
#include "absl/strings/str_format.h"
#include "absl/strings/str_join.h"
#include "tensorflow/compiler/xla/map_util.h"
@@ -119,7 +120,7 @@ Status HloSchedule::UpdateComputationSchedule(
}
// Set of all HloInstructions in the schedule.
- tensorflow::gtl::FlatSet<int> ids_in_schedule;
+ absl::flat_hash_set<int> ids_in_schedule;
for (int id : sequences_.at(computation->unique_id()).ids()) {
InsertOrDie(&ids_in_schedule, id);
}
@@ -210,7 +211,7 @@ Status HloSchedule::Update() {
if (sequences_.size() > nonfusion_computations.size()) {
// Schedule contains some computations which have been removed from the
// HloModule. Remove them from the schedule as well.
- tensorflow::gtl::FlatSet<int64> nonfusion_computations_ids;
+ absl::flat_hash_set<int64> nonfusion_computations_ids;
for (const HloComputation* computation : nonfusion_computations) {
nonfusion_computations_ids.insert(computation->unique_id());
}
diff --git a/tensorflow/compiler/xla/service/hlo_value.cc b/tensorflow/compiler/xla/service/hlo_value.cc
index 8549487702..59594ab2f0 100644
--- a/tensorflow/compiler/xla/service/hlo_value.cc
+++ b/tensorflow/compiler/xla/service/hlo_value.cc
@@ -18,6 +18,7 @@ limitations under the License.
#include <algorithm>
#include <utility>
+#include "absl/container/flat_hash_set.h"
#include "absl/memory/memory.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/str_join.h"
@@ -31,7 +32,6 @@ limitations under the License.
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/core/lib/core/errors.h"
-#include "tensorflow/core/lib/gtl/flatset.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/types.h"
@@ -167,7 +167,7 @@ void HloValue::SetPositionsAndComputeUses(
positions_.insert(positions_.end(), positions.begin(), positions.end());
// Gather the computation roots at which this value appears.
- tensorflow::gtl::FlatSet<HloInstruction*> root_positions;
+ absl::flat_hash_set<HloInstruction*> root_positions;
for (const HloPosition& position : positions_) {
if (position.instruction ==
position.instruction->parent()->root_instruction()) {
diff --git a/tensorflow/compiler/xla/service/indexed_array_analysis.cc b/tensorflow/compiler/xla/service/indexed_array_analysis.cc
index 7ee789276d..1ebb331977 100644
--- a/tensorflow/compiler/xla/service/indexed_array_analysis.cc
+++ b/tensorflow/compiler/xla/service/indexed_array_analysis.cc
@@ -17,6 +17,7 @@ limitations under the License.
#include "absl/algorithm/container.h"
#include "absl/container/flat_hash_map.h"
+#include "absl/container/flat_hash_set.h"
#include "absl/container/inlined_vector.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/str_join.h"
@@ -24,7 +25,6 @@ limitations under the License.
#include "tensorflow/compiler/xla/map_util.h"
#include "tensorflow/compiler/xla/service/hlo_evaluator.h"
#include "tensorflow/compiler/xla/util.h"
-#include "tensorflow/core/lib/gtl/flatset.h"
namespace xla {
namespace gtl = ::tensorflow::gtl;
diff --git a/tensorflow/compiler/xla/service/layout_assignment.h b/tensorflow/compiler/xla/service/layout_assignment.h
index 1591256fad..15f0adcaaf 100644
--- a/tensorflow/compiler/xla/service/layout_assignment.h
+++ b/tensorflow/compiler/xla/service/layout_assignment.h
@@ -26,6 +26,7 @@ limitations under the License.
#include <vector>
#include "absl/container/flat_hash_map.h"
+#include "absl/container/flat_hash_set.h"
#include "tensorflow/compiler/xla/service/computation_layout.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
@@ -39,7 +40,6 @@ limitations under the License.
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/lib/core/status.h"
-#include "tensorflow/core/lib/gtl/flatset.h"
#include "tensorflow/core/platform/types.h"
namespace xla {
@@ -504,7 +504,7 @@ class LayoutAssignment : public HloModulePass {
// Every copy added to the module by the layout assignment pass is registered
// here.
- tensorflow::gtl::FlatSet<HloInstruction*> added_copies_;
+ absl::flat_hash_set<HloInstruction*> added_copies_;
// The pointer to the channel layout constraints passed in with the
// constructor. If not nullptr, this is an input/output argument.
@@ -521,8 +521,7 @@ class LayoutAssignment : public HloModulePass {
// The set of HLO instructions which lacked any layout constraint, thus
// receiving propagated default layouts.
- tensorflow::gtl::FlatSet<const HloInstruction*>
- unconstrained_layout_instructions_;
+ absl::flat_hash_set<const HloInstruction*> unconstrained_layout_instructions_;
};
} // namespace xla
diff --git a/tensorflow/compiler/xla/service/llvm_ir/BUILD b/tensorflow/compiler/xla/service/llvm_ir/BUILD
index 3934d2e493..6223a34b12 100644
--- a/tensorflow/compiler/xla/service/llvm_ir/BUILD
+++ b/tensorflow/compiler/xla/service/llvm_ir/BUILD
@@ -39,6 +39,7 @@ cc_library(
"//tensorflow/compiler/xla/service:logical_buffer",
"//tensorflow/core:lib",
"@com_google_absl//absl/container:flat_hash_map",
+ "@com_google_absl//absl/container:flat_hash_set",
"@com_google_absl//absl/strings",
"@llvm//:core",
],
diff --git a/tensorflow/compiler/xla/service/llvm_ir/alias_analysis.cc b/tensorflow/compiler/xla/service/llvm_ir/alias_analysis.cc
index e5370eca56..643ecd0fba 100644
--- a/tensorflow/compiler/xla/service/llvm_ir/alias_analysis.cc
+++ b/tensorflow/compiler/xla/service/llvm_ir/alias_analysis.cc
@@ -15,7 +15,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/llvm_ir/alias_analysis.h"
-#include <unordered_set>
+#include <map>
#include "llvm/IR/MDBuilder.h"
#include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h"
@@ -164,9 +164,7 @@ llvm::MDNode* AliasAnalysis::GetNoaliasMetadataForBuffer(
add_buffers_to_worklist(operand);
}
- tensorflow::gtl::FlatSet<BufferAllocation::Slice,
- BufferAllocation::Slice::Hasher>
- buffers;
+ std::set<BufferAllocation::Slice> buffers;
for (const LogicalBuffer* buffer : worklist) {
// Skip buffers which cannot be added to the noalias set.
if (!assignment.HasAllocation(*buffer) ||
diff --git a/tensorflow/compiler/xla/service/llvm_ir/alias_analysis.h b/tensorflow/compiler/xla/service/llvm_ir/alias_analysis.h
index 88cde2d3d9..2b46b3c396 100644
--- a/tensorflow/compiler/xla/service/llvm_ir/alias_analysis.h
+++ b/tensorflow/compiler/xla/service/llvm_ir/alias_analysis.h
@@ -23,7 +23,6 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/llvm_ir/ir_array.h"
#include "tensorflow/compiler/xla/types.h"
-#include "tensorflow/core/lib/gtl/flatset.h"
namespace xla {
namespace llvm_ir {
diff --git a/tensorflow/compiler/xla/service/multi_output_fusion.cc b/tensorflow/compiler/xla/service/multi_output_fusion.cc
index 95b1c20663..2ca527bc4c 100644
--- a/tensorflow/compiler/xla/service/multi_output_fusion.cc
+++ b/tensorflow/compiler/xla/service/multi_output_fusion.cc
@@ -15,10 +15,10 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/multi_output_fusion.h"
+#include "absl/container/flat_hash_set.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
#include "tensorflow/compiler/xla/shape_util.h"
-#include "tensorflow/core/lib/gtl/flatset.h"
#include "tensorflow/core/platform/types.h"
namespace xla {
@@ -50,7 +50,7 @@ StatusOr<bool> MultiOutputFusion::Run(HloModule* module) {
all_fusion_candidates_.push_back(instruction);
std::vector<HloInstruction*> candidates;
- tensorflow::gtl::FlatSet<HloInstruction*> candidates_set;
+ absl::flat_hash_set<HloInstruction*> candidates_set;
VLOG(10) << "Looking at instruction: " << instruction->name();
for (auto operand : instruction->operands()) {
// Filter out the non-interesting instructions -- they
@@ -172,7 +172,7 @@ void MultiOutputFusion::Update(HloInstruction* instr1, HloInstruction* instr2) {
// Update the fusible list for fusion. Variable new_fusibles keeps
// track of the new or changed entries.
std::vector<std::pair<HloInstruction*, int64>> new_fusibles;
- tensorflow::gtl::FlatSet<HloInstruction*> in_list;
+ absl::flat_hash_set<HloInstruction*> in_list;
auto it = fusion_node.fusibles.begin();
while (it != fusion_node.fusibles.end()) {
HloInstruction* instr = it->first;
diff --git a/tensorflow/compiler/xla/service/name_uniquer.h b/tensorflow/compiler/xla/service/name_uniquer.h
index 1ac60f1cf4..8909d0f4fe 100644
--- a/tensorflow/compiler/xla/service/name_uniquer.h
+++ b/tensorflow/compiler/xla/service/name_uniquer.h
@@ -19,9 +19,9 @@ limitations under the License.
#include <string>
#include "absl/container/flat_hash_map.h"
+#include "absl/container/flat_hash_set.h"
#include "absl/strings/string_view.h"
#include "tensorflow/compiler/xla/types.h"
-#include "tensorflow/core/lib/gtl/flatset.h"
#include "tensorflow/core/platform/macros.h"
namespace xla {
@@ -69,7 +69,7 @@ class NameUniquer {
int64 next_ = 0;
// Set of all the identifiers which has been used.
- tensorflow::gtl::FlatSet<int64> used_;
+ absl::flat_hash_set<int64> used_;
};
// The string to use to separate the prefix of the name from the uniquing
diff --git a/tensorflow/compiler/xla/service/shape_inference.cc b/tensorflow/compiler/xla/service/shape_inference.cc
index 6ccea9d2b5..e379911462 100644
--- a/tensorflow/compiler/xla/service/shape_inference.cc
+++ b/tensorflow/compiler/xla/service/shape_inference.cc
@@ -22,6 +22,7 @@ limitations under the License.
#include <string>
#include "absl/algorithm/container.h"
+#include "absl/container/flat_hash_set.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/str_format.h"
#include "absl/strings/str_join.h"
@@ -33,7 +34,6 @@ limitations under the License.
#include "tensorflow/compiler/xla/window_util.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/lib/core/errors.h"
-#include "tensorflow/core/lib/gtl/flatset.h"
#include "tensorflow/core/lib/math/math_util.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/protobuf.h"
@@ -577,7 +577,7 @@ Status ValidateDotDimensionNumbers(
// Check that dimension numbers are unique.
auto dims_unique = [](absl::Span<const int64> contracting_dims,
absl::Span<const int64> batch_dims) -> bool {
- tensorflow::gtl::FlatSet<int64> dim_set;
+ absl::flat_hash_set<int64> dim_set;
auto is_unique = [&dim_set](int64 i) -> bool {
return dim_set.insert(i).second;
};
diff --git a/tensorflow/compiler/xla/service/shaped_buffer.cc b/tensorflow/compiler/xla/service/shaped_buffer.cc
index 921a984589..56952e3ada 100644
--- a/tensorflow/compiler/xla/service/shaped_buffer.cc
+++ b/tensorflow/compiler/xla/service/shaped_buffer.cc
@@ -18,6 +18,7 @@ limitations under the License.
#include <string>
#include <utility>
+#include "absl/container/flat_hash_set.h"
#include "absl/memory/memory.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/str_format.h"
@@ -26,7 +27,6 @@ limitations under the License.
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/util.h"
-#include "tensorflow/core/lib/gtl/flatset.h"
#include "tensorflow/core/platform/logging.h"
namespace xla {
@@ -147,7 +147,7 @@ void ScopedShapedBuffer::Deallocate() {
// Deallocate all non-null buffers. A buffer may appear in more than one spot
// in the shape (eg, a tuple with a repeated element) so keep track of what
// has been deallocated.
- tensorflow::gtl::FlatSet<void*> deallocated_ptrs;
+ absl::flat_hash_set<void*> deallocated_ptrs;
for (auto& pair : buffers_) {
se::DeviceMemoryBase& memory_base = pair.second;
if (!memory_base.is_null() &&
diff --git a/tensorflow/compiler/xla/service/tuple_points_to_analysis.h b/tensorflow/compiler/xla/service/tuple_points_to_analysis.h
index 78392d3bb2..64ad1dc80e 100644
--- a/tensorflow/compiler/xla/service/tuple_points_to_analysis.h
+++ b/tensorflow/compiler/xla/service/tuple_points_to_analysis.h
@@ -36,7 +36,6 @@ limitations under the License.
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/gtl/compactptrset.h"
-#include "tensorflow/core/lib/gtl/flatset.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/types.h"
diff --git a/tensorflow/compiler/xla/service/while_loop_invariant_code_motion.cc b/tensorflow/compiler/xla/service/while_loop_invariant_code_motion.cc
index 2590473c77..9795b2830b 100644
--- a/tensorflow/compiler/xla/service/while_loop_invariant_code_motion.cc
+++ b/tensorflow/compiler/xla/service/while_loop_invariant_code_motion.cc
@@ -16,17 +16,17 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/while_loop_invariant_code_motion.h"
#include "absl/algorithm/container.h"
#include "absl/container/flat_hash_map.h"
+#include "absl/container/flat_hash_set.h"
#include "absl/container/inlined_vector.h"
#include "tensorflow/compiler/xla/service/tuple_util.h"
#include "tensorflow/compiler/xla/service/while_util.h"
#include "tensorflow/compiler/xla/util.h"
-#include "tensorflow/core/lib/gtl/flatset.h"
namespace xla {
using absl::flat_hash_map;
+using absl::flat_hash_set;
using absl::InlinedVector;
-using tensorflow::gtl::FlatSet;
// Copies `to_hoist` to the computation containing `while_instr`, hoisting its
// operands as needed. All of its transitive operands are expected to be either
@@ -35,7 +35,7 @@ using tensorflow::gtl::FlatSet;
// them into `hoisted_instructions`.
static void CreateLoopInvariantCopy(
flat_hash_map<HloInstruction*, HloInstruction*>* hoisted_instructions,
- FlatSet<HloInstruction*>* unhoisted_invariant_instructions,
+ flat_hash_set<HloInstruction*>* unhoisted_invariant_instructions,
HloInstruction* while_instr, HloInstruction* to_hoist) {
HloComputation* parent_of_while = while_instr->parent();
HloComputation* while_body = while_instr->while_body();
@@ -153,7 +153,7 @@ WhileLoopInvariantCodeMotion::TryHoistingInvariantInstructionsFromWhileBody(
// unprofitable to be hoisted alone by NotWorthHoistingIndividually. When we
// hoist an instruction in this set, we move it from
// unhoisted_invariant_instructions to hoisted_instructions.
- FlatSet<HloInstruction*> unhoisted_invariant_instructions;
+ flat_hash_set<HloInstruction*> unhoisted_invariant_instructions;
// Invariant GTE's axiomatically satisfy the constraints for
// unhoisted_invariant_instructions -- they can be legally hoisted, but there
diff --git a/tensorflow/compiler/xla/service/while_loop_simplifier.cc b/tensorflow/compiler/xla/service/while_loop_simplifier.cc
index 07de8492ba..630d71e5ca 100644
--- a/tensorflow/compiler/xla/service/while_loop_simplifier.cc
+++ b/tensorflow/compiler/xla/service/while_loop_simplifier.cc
@@ -15,6 +15,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/while_loop_simplifier.h"
#include "absl/container/flat_hash_map.h"
+#include "absl/container/flat_hash_set.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/str_join.h"
#include "absl/types/optional.h"
@@ -114,7 +115,7 @@ static StatusOr<bool> TryRemoveDeadWhileParams(HloInstruction* while_op) {
return false;
}
- tensorflow::gtl::FlatSet<int64> used_tuple_indices;
+ absl::flat_hash_set<int64> used_tuple_indices;
for (HloComputation* comp : {while_body, while_cond}) {
// The HLO verifier ensures that while_input's shape matches while_init's
// shape, which we verified above is a tuple.
diff --git a/tensorflow/compiler/xla/tests/BUILD b/tensorflow/compiler/xla/tests/BUILD
index 06b6330321..8a0ae33042 100644
--- a/tensorflow/compiler/xla/tests/BUILD
+++ b/tensorflow/compiler/xla/tests/BUILD
@@ -2146,11 +2146,11 @@ xla_test(
":test_utils",
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla/client:xla_builder",
- "//tensorflow/compiler/xla/client:xla_computation",
"//tensorflow/compiler/xla/service:hlo_parser",
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
"//tensorflow/core:lib",
"//tensorflow/core:test",
+ "@com_google_absl//absl/container:flat_hash_set",
],
)
diff --git a/tensorflow/compiler/xla/tests/test_utils_test.cc b/tensorflow/compiler/xla/tests/test_utils_test.cc
index 181e5cbe29..bc433eac8f 100644
--- a/tensorflow/compiler/xla/tests/test_utils_test.cc
+++ b/tensorflow/compiler/xla/tests/test_utils_test.cc
@@ -15,6 +15,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/tests/test_utils.h"
+#include "absl/container/flat_hash_set.h"
#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/service/hlo_parser.h"
#include "tensorflow/compiler/xla/shape_util.h"
@@ -145,7 +146,7 @@ ENTRY %sort.148.1589 (parameter.0: f32[1048576], parameter.1: s32[1048576]) -> (
ASSERT_EQ(args.size(), 2);
const Literal& key_arg = args[0];
- tensorflow::gtl::FlatSet<uint32> key_set;
+ absl::flat_hash_set<uint32> key_set;
for (const float& value : key_arg.data<float>()) {
EXPECT_TRUE(key_set.insert(tensorflow::bit_cast<uint32>(value)).second);
}
@@ -168,7 +169,7 @@ ENTRY %sort.148.1589 (parameter.0: s32[1048576], parameter.1: s32[1048576]) -> (
ASSERT_EQ(args.size(), 2);
const Literal& key_arg = args[0];
- tensorflow::gtl::FlatSet<int32> key_set;
+ absl::flat_hash_set<int32> key_set;
for (const int32& value : key_arg.data<int32>()) {
EXPECT_TRUE(key_set.insert(tensorflow::bit_cast<uint32>(value)).second);
}