diff options
author | Benjamin Kramer <kramerb@google.com> | 2018-10-01 19:42:12 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-10-01 19:46:29 -0700 |
commit | 991f06fd50fc73285ce415d57f720994c2b2e861 (patch) | |
tree | 0b87402386aad22ec958f171bfd57f9c7c3e8571 | |
parent | beede8525be5386451bf0098992c37416d1864db (diff) |
[XLA] Migrate from gtl::FlatSet to absl::flat_hash_set
PiperOrigin-RevId: 215324035
63 files changed, 235 insertions, 186 deletions
diff --git a/tensorflow/compiler/jit/BUILD b/tensorflow/compiler/jit/BUILD index f20270931f..661b444a42 100644 --- a/tensorflow/compiler/jit/BUILD +++ b/tensorflow/compiler/jit/BUILD @@ -325,6 +325,7 @@ cc_library( "//tensorflow/core:lib", "//tensorflow/core:protos_all_cc", "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/memory", "@com_google_absl//absl/strings", "@com_google_absl//absl/types:optional", @@ -407,6 +408,7 @@ cc_library( "//tensorflow/core/kernels:bounds_check", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/memory", "@com_google_absl//absl/strings", ], diff --git a/tensorflow/compiler/jit/deadness_analysis.cc b/tensorflow/compiler/jit/deadness_analysis.cc index e63d4b7792..e0b9932d80 100644 --- a/tensorflow/compiler/jit/deadness_analysis.cc +++ b/tensorflow/compiler/jit/deadness_analysis.cc @@ -16,11 +16,11 @@ limitations under the License. #include "tensorflow/compiler/jit/deadness_analysis.h" #include "absl/algorithm/container.h" #include "absl/container/flat_hash_map.h" +#include "absl/container/flat_hash_set.h" #include "absl/strings/str_join.h" #include "tensorflow/compiler/jit/deadness_analysis_internal.h" #include "tensorflow/core/graph/algorithm.h" #include "tensorflow/core/graph/tensor_id.h" -#include "tensorflow/core/lib/gtl/flatset.h" #include "tensorflow/core/lib/hash/hash.h" // ALGORITHM OVERVIEW @@ -298,7 +298,7 @@ class SymbolPredicate : public Predicate { template <typename FunctionTy> /*static*/ void Predicate::Visit(Predicate* p, const FunctionTy& func) { - gtl::FlatSet<Predicate*> visited; + absl::flat_hash_set<Predicate*> visited; std::vector<Predicate*> stack; stack.push_back(p); @@ -467,7 +467,7 @@ Predicate* PredicateFactory::MakeAndOrImpl( is_and ? Predicate::Kind::kAnd : Predicate::Kind::kOr; Predicate::Kind other_pred_kind = is_and ? Predicate::Kind::kOr : Predicate::Kind::kAnd; - gtl::FlatSet<Predicate*> simplified_ops_set; + absl::flat_hash_set<Predicate*> simplified_ops_set; std::vector<Predicate*> simplified_ops; for (Predicate* op : operands) { // Simplify A&A => A and A|A => A. @@ -492,7 +492,7 @@ Predicate* PredicateFactory::MakeAndOrImpl( } // Simplify "A&~A=>False" and "A|~A=>True". - gtl::FlatSet<Predicate*> negated_ops; + absl::flat_hash_set<Predicate*> negated_ops; for (Predicate* op : simplified_ops) { if (op->kind() == Predicate::Kind::kNot) { negated_ops.insert(dynamic_cast<NotPredicate&>(*op).operand()); @@ -512,7 +512,7 @@ Predicate* PredicateFactory::MakeAndOrImpl( // // First find any predicates contained in all subops. std::vector<Predicate*> common_inner_operands; - gtl::FlatSet<Predicate*> common_inner_operands_set; + absl::flat_hash_set<Predicate*> common_inner_operands_set; for (Predicate* op : simplified_ops) { if (op->kind() != other_pred_kind) { common_inner_operands.clear(); diff --git a/tensorflow/compiler/jit/encapsulate_subgraphs_pass.cc b/tensorflow/compiler/jit/encapsulate_subgraphs_pass.cc index d165341f21..da27f837e8 100644 --- a/tensorflow/compiler/jit/encapsulate_subgraphs_pass.cc +++ b/tensorflow/compiler/jit/encapsulate_subgraphs_pass.cc @@ -22,6 +22,7 @@ limitations under the License. #include <unordered_map> #include <vector> +#include "absl/container/flat_hash_set.h" #include "absl/strings/match.h" #include "absl/strings/str_cat.h" #include "tensorflow/compiler/jit/graphcycles/graphcycles.h" @@ -44,7 +45,6 @@ limitations under the License. #include "tensorflow/core/graph/graph.h" #include "tensorflow/core/graph/graph_def_builder.h" #include "tensorflow/core/graph/tensor_id.h" -#include "tensorflow/core/lib/gtl/flatset.h" #include "tensorflow/core/lib/gtl/map_util.h" #include "tensorflow/core/lib/hash/hash.h" #include "tensorflow/core/public/session_options.h" @@ -78,7 +78,8 @@ void SortControlInputs(GraphDef* gdef) { namespace { bool AreAllParentsGuaranteedConst( - const Node& n, const gtl::FlatSet<const Node*>& runtime_const_nodes) { + const Node& n, + const absl::flat_hash_set<const Node*>& runtime_const_nodes) { if (n.type_string() == "GuaranteeConst") { // If the current node is itself a cast-to-const, no need // to look at the incoming edges. @@ -101,7 +102,7 @@ bool AreAllParentsGuaranteedConst( void MarkGuaranteedConstants( const Graph& graph, const std::vector<std::pair<const Node*, Node*>>& src_arg_pairs) { - gtl::FlatSet<const Node*> guaranteed_const_nodes; + absl::flat_hash_set<const Node*> guaranteed_const_nodes; std::vector<const Node*> srcs; srcs.reserve(src_arg_pairs.size()); for (const auto& src_arg : src_arg_pairs) { diff --git a/tensorflow/compiler/jit/encapsulate_xla_computations_pass.cc b/tensorflow/compiler/jit/encapsulate_xla_computations_pass.cc index 755c364c62..2ce6fa73fc 100644 --- a/tensorflow/compiler/jit/encapsulate_xla_computations_pass.cc +++ b/tensorflow/compiler/jit/encapsulate_xla_computations_pass.cc @@ -15,13 +15,13 @@ limitations under the License. #include "tensorflow/compiler/jit/encapsulate_xla_computations_pass.h" +#include "absl/container/flat_hash_set.h" #include "absl/memory/memory.h" #include "absl/strings/str_cat.h" #include "tensorflow/compiler/jit/encapsulate_subgraphs_pass.h" #include "tensorflow/compiler/tf2xla/dump_graph.h" #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/core/framework/node_def.pb.h" -#include "tensorflow/core/lib/gtl/flatset.h" #include "tensorflow/core/lib/hash/hash.h" #include "tensorflow/core/lib/strings/proto_serialization.h" #include "tensorflow/core/lib/strings/str_util.h" @@ -62,7 +62,7 @@ DataType EdgeType(const Edge* edge) { } // Adds the control inputs of `node` to `*deps`. -void AddControlInputs(const Node& node, gtl::FlatSet<Node*>* deps) { +void AddControlInputs(const Node& node, absl::flat_hash_set<Node*>* deps) { for (const Edge* edge : node.in_edges()) { if (edge->IsControlEdge()) { deps->insert(edge->src()); @@ -71,7 +71,7 @@ void AddControlInputs(const Node& node, gtl::FlatSet<Node*>* deps) { } // Adds the control outputs of `node` to `*deps`. -void AddControlOutputs(const Node& node, gtl::FlatSet<Node*>* deps) { +void AddControlOutputs(const Node& node, absl::flat_hash_set<Node*>* deps) { for (const Edge* edge : node.out_edges()) { if (edge->IsControlEdge()) { deps->insert(edge->dst()); @@ -246,7 +246,7 @@ Status RewriteSubgraph(const std::vector<OutputTensor>& arg_source_tensors, // Data and control inputs to the new XlaLaunch node. std::vector<std::pair<Node*, int>> data_inputs(num_inputs); - gtl::FlatSet<Node*> control_inputs; + absl::flat_hash_set<Node*> control_inputs; DataTypeVector arg_types(num_args); AddControlInputs(*launch, &control_inputs); @@ -266,7 +266,7 @@ Status RewriteSubgraph(const std::vector<OutputTensor>& arg_source_tensors, // Outputs. const int num_outputs = launch->output_types().size(); - gtl::FlatSet<Node*> control_outputs; + absl::flat_hash_set<Node*> control_outputs; std::vector<std::vector<std::pair<Node*, int>>> data_outputs(num_outputs); DataTypeVector output_types(num_outputs); diff --git a/tensorflow/compiler/jit/mark_for_compilation_pass.cc b/tensorflow/compiler/jit/mark_for_compilation_pass.cc index 133d982360..4f0c370e65 100644 --- a/tensorflow/compiler/jit/mark_for_compilation_pass.cc +++ b/tensorflow/compiler/jit/mark_for_compilation_pass.cc @@ -21,6 +21,7 @@ limitations under the License. #include <unordered_map> #include <unordered_set> +#include "absl/container/flat_hash_set.h" #include "tensorflow/compiler/jit/deadness_analysis.h" #include "tensorflow/compiler/jit/defs.h" #include "tensorflow/compiler/jit/graphcycles/graphcycles.h" @@ -42,7 +43,6 @@ limitations under the License. #include "tensorflow/core/graph/control_flow.h" #include "tensorflow/core/kernels/bounds_check.h" #include "tensorflow/core/lib/gtl/cleanup.h" -#include "tensorflow/core/lib/gtl/flatset.h" #include "tensorflow/core/lib/strings/stringprintf.h" #include "tensorflow/core/public/version.h" @@ -371,7 +371,7 @@ bool IsXlaFusable(const NodeDef& node) { Status FindCompilationCandidates( const Graph& graph, FunctionLibraryDefinition* flib_def, Env* env, const std::function<bool(const Node*, const DeviceType&)>& is_compilable_fn, - OrderedNodeSet* candidates, gtl::FlatSet<Node*>* isolated_nodes) { + OrderedNodeSet* candidates, absl::flat_hash_set<Node*>* isolated_nodes) { OptimizerOptions opts; std::unique_ptr<ProcessFunctionLibraryRuntime> pflr( new ProcessFunctionLibraryRuntime(nullptr, env, TF_GRAPH_DEF_VERSION, @@ -849,7 +849,7 @@ Status MarkForCompilationPass::RunImpl( Graph* graph = options.graph->get(); OrderedNodeSet compilation_candidates; - gtl::FlatSet<Node*> isolated_nodes; + absl::flat_hash_set<Node*> isolated_nodes; TF_RETURN_IF_ERROR(FindCompilationCandidates( *graph, options.flib_def, (options.session_options != nullptr) ? options.session_options->env diff --git a/tensorflow/compiler/jit/partially_decluster_pass.cc b/tensorflow/compiler/jit/partially_decluster_pass.cc index 10fc9e85d9..b1f9e9088f 100644 --- a/tensorflow/compiler/jit/partially_decluster_pass.cc +++ b/tensorflow/compiler/jit/partially_decluster_pass.cc @@ -15,17 +15,18 @@ limitations under the License. #include "tensorflow/compiler/jit/partially_decluster_pass.h" #include "absl/algorithm/container.h" +#include "absl/container/flat_hash_set.h" #include "absl/strings/str_cat.h" #include "tensorflow/compiler/jit/xla_cluster_util.h" #include "tensorflow/compiler/tf2xla/const_analysis.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" #include "tensorflow/core/framework/memory_types.h" #include "tensorflow/core/framework/node_def.pb.h" -#include "tensorflow/core/lib/gtl/flatset.h" namespace tensorflow { namespace { -Status FindNodesToDecluster(const Graph& graph, gtl::FlatSet<Node*>* result, +Status FindNodesToDecluster(const Graph& graph, + absl::flat_hash_set<Node*>* result, absl::Span<Node* const> post_order) { // Find nodes that have at least one user outside their cluster that expects // hostmem output. These nodes should be cloned to outside the cluster to @@ -171,7 +172,7 @@ Status PartiallyDeclusterToRemoveDeviceToHostCopies(Graph* graph) { GetPostOrder(*graph, &post_order, /*stable_comparator=*/NodeComparatorName(), /*edge_filter=*/NotBackedge); - gtl::FlatSet<Node*> nodes_to_partially_decluster; + absl::flat_hash_set<Node*> nodes_to_partially_decluster; TF_RETURN_IF_ERROR( FindNodesToDecluster(*graph, &nodes_to_partially_decluster, post_order)); diff --git a/tensorflow/compiler/jit/resource_operation_safety_analysis.cc b/tensorflow/compiler/jit/resource_operation_safety_analysis.cc index 657bb409db..e039d46ec8 100644 --- a/tensorflow/compiler/jit/resource_operation_safety_analysis.cc +++ b/tensorflow/compiler/jit/resource_operation_safety_analysis.cc @@ -82,6 +82,7 @@ limitations under the License. #include "tensorflow/compiler/jit/resource_operation_safety_analysis.h" +#include "absl/container/flat_hash_set.h" #include "absl/memory/memory.h" #include "absl/strings/str_join.h" #include "absl/types/optional.h" @@ -89,7 +90,6 @@ limitations under the License. #include "tensorflow/core/framework/node_def.pb.h" #include "tensorflow/core/graph/algorithm.h" #include "tensorflow/core/graph/tensor_id.h" -#include "tensorflow/core/lib/gtl/flatset.h" #include "tensorflow/core/lib/hash/hash.h" #include "tensorflow/core/util/ptr_util.h" @@ -176,7 +176,7 @@ string ResourceOpToString(const ResourceOp& resource_op) { // point. class ResourceOpSet { private: - using Impl = gtl::FlatSet<ResourceOp>; + using Impl = absl::flat_hash_set<ResourceOp>; public: ResourceOpSet() = default; diff --git a/tensorflow/compiler/tests/BUILD b/tensorflow/compiler/tests/BUILD index 3cf74fa788..822fedf121 100644 --- a/tensorflow/compiler/tests/BUILD +++ b/tensorflow/compiler/tests/BUILD @@ -1105,6 +1105,7 @@ cc_library( "//tensorflow/core:test", "//tensorflow/core:testlib", "//tensorflow/core/kernels:ops_util", + "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/strings", ], ) diff --git a/tensorflow/compiler/tests/randomized_tests.cc b/tensorflow/compiler/tests/randomized_tests.cc index bddda6f302..7a96f4c25c 100644 --- a/tensorflow/compiler/tests/randomized_tests.cc +++ b/tensorflow/compiler/tests/randomized_tests.cc @@ -45,6 +45,7 @@ limitations under the License. #include <random> #include <unordered_map> +#include "absl/container/flat_hash_set.h" #include "absl/strings/str_cat.h" #include "absl/strings/string_view.h" #include "tensorflow/compiler/jit/defs.h" @@ -63,7 +64,6 @@ limitations under the License. #include "tensorflow/core/kernels/ops_util.h" #include "tensorflow/core/lib/core/status.h" #include "tensorflow/core/lib/core/status_test_util.h" -#include "tensorflow/core/lib/gtl/flatset.h" #include "tensorflow/core/platform/test.h" #include "tensorflow/core/public/session.h" #include "tensorflow/core/public/session_options.h" @@ -457,7 +457,7 @@ Tensor OpTest::RandomTensor(DataType dtype, bool needs_unique_values, Tensor tensor(dtype, TensorShape(shape)); switch (dtype) { case DT_FLOAT: { - gtl::FlatSet<float> already_generated; + absl::flat_hash_set<float> already_generated; std::uniform_real_distribution<float> distribution(-1.0f, 1.0f); test::FillFn<float>(&tensor, [&](int i) -> float { float generated; @@ -470,7 +470,7 @@ Tensor OpTest::RandomTensor(DataType dtype, bool needs_unique_values, break; } case DT_DOUBLE: { - gtl::FlatSet<double> already_generated; + absl::flat_hash_set<double> already_generated; std::uniform_real_distribution<double> distribution(-1.0, 1.0); test::FillFn<double>(&tensor, [&](int i) -> double { double generated; @@ -483,7 +483,7 @@ Tensor OpTest::RandomTensor(DataType dtype, bool needs_unique_values, break; } case DT_COMPLEX64: { - gtl::FlatSet<std::pair<float, float>> already_generated; + absl::flat_hash_set<std::pair<float, float>> already_generated; std::uniform_real_distribution<float> distribution(-1.0f, 1.0f); test::FillFn<complex64>(&tensor, [&](int i) { complex64 generated; @@ -500,7 +500,7 @@ Tensor OpTest::RandomTensor(DataType dtype, bool needs_unique_values, break; } case DT_INT32: { - gtl::FlatSet<int32> already_generated; + absl::flat_hash_set<int32> already_generated; std::uniform_int_distribution<int32> distribution(-(1 << 20), 1 << 20); test::FillFn<int32>(&tensor, [&](int i) -> int32 { int32 generated; @@ -513,7 +513,7 @@ Tensor OpTest::RandomTensor(DataType dtype, bool needs_unique_values, break; } case DT_INT64: { - gtl::FlatSet<int64> already_generated; + absl::flat_hash_set<int64> already_generated; std::uniform_int_distribution<int64> distribution(-(1LL << 40), 1LL << 40); test::FillFn<int64>(&tensor, [&](int i) -> int64 { @@ -527,7 +527,7 @@ Tensor OpTest::RandomTensor(DataType dtype, bool needs_unique_values, break; } case DT_BOOL: { - gtl::FlatSet<bool> already_generated; + absl::flat_hash_set<bool> already_generated; std::bernoulli_distribution distribution; test::FillFn<bool>(&tensor, [&](int i) -> bool { bool generated; diff --git a/tensorflow/compiler/xla/client/BUILD b/tensorflow/compiler/xla/client/BUILD index 1191cff109..dc097f3696 100644 --- a/tensorflow/compiler/xla/client/BUILD +++ b/tensorflow/compiler/xla/client/BUILD @@ -221,6 +221,7 @@ cc_library( "//tensorflow/core:lib", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/memory", "@com_google_absl//absl/strings", "@com_google_absl//absl/types:span", diff --git a/tensorflow/compiler/xla/client/xla_builder.cc b/tensorflow/compiler/xla/client/xla_builder.cc index 5277de6a85..e0ec91dba1 100644 --- a/tensorflow/compiler/xla/client/xla_builder.cc +++ b/tensorflow/compiler/xla/client/xla_builder.cc @@ -22,6 +22,7 @@ limitations under the License. #include <utility> #include "absl/algorithm/container.h" +#include "absl/container/flat_hash_set.h" #include "absl/memory/memory.h" #include "absl/strings/match.h" #include "absl/strings/str_cat.h" @@ -33,7 +34,6 @@ limitations under the License. #include "tensorflow/compiler/xla/service/hlo_opcode.h" #include "tensorflow/compiler/xla/service/shape_inference.h" #include "tensorflow/compiler/xla/util.h" -#include "tensorflow/core/lib/gtl/flatset.h" #include "tensorflow/core/platform/mutex.h" namespace xla { @@ -2290,7 +2290,7 @@ StatusOr<XlaComputation> XlaBuilder::BuildConstantSubGraph( // also a valid dependency order). The related ops will be added to the // subgraph in the same order. std::set<int64> related_ops; - tensorflow::gtl::FlatSet<int64> related_calls; // Related computations. + absl::flat_hash_set<int64> related_calls; // Related computations. std::queue<int64> worklist; worklist.push(root->id()); related_ops.insert(root->id()); diff --git a/tensorflow/compiler/xla/client/xla_builder.h b/tensorflow/compiler/xla/client/xla_builder.h index b7295e8a53..cd0d5ca5d3 100644 --- a/tensorflow/compiler/xla/client/xla_builder.h +++ b/tensorflow/compiler/xla/client/xla_builder.h @@ -22,6 +22,7 @@ limitations under the License. #include <utility> #include "absl/container/flat_hash_map.h" +#include "absl/container/flat_hash_set.h" #include "absl/strings/string_view.h" #include "absl/types/span.h" #include "tensorflow/compiler/xla/client/padding.h" @@ -35,7 +36,6 @@ limitations under the License. #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/xla_data.pb.h" -#include "tensorflow/core/lib/gtl/flatset.h" #include "tensorflow/core/platform/macros.h" #include "tensorflow/core/platform/stacktrace.h" #include "tensorflow/core/platform/types.h" @@ -1035,7 +1035,7 @@ class XlaBuilder { std::map<int64, HloComputationProto> embedded_; // The unique parameter numbers. - tensorflow::gtl::FlatSet<int64> parameter_numbers_; + absl::flat_hash_set<int64> parameter_numbers_; // The metadata to attach to each op. This is structured as a "modal"-like // operation, in order to simplify client code (and not sprinkle this metadata diff --git a/tensorflow/compiler/xla/service/BUILD b/tensorflow/compiler/xla/service/BUILD index 8da6364786..13803f5ebe 100644 --- a/tensorflow/compiler/xla/service/BUILD +++ b/tensorflow/compiler/xla/service/BUILD @@ -147,6 +147,7 @@ cc_library( "//tensorflow/compiler/xla:util", "//tensorflow/core:lib", "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/container:flat_hash_set", ], ) @@ -183,6 +184,7 @@ cc_library( "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/core:lib", "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/strings", "@com_google_absl//absl/strings:str_format", "@com_google_absl//absl/types:span", @@ -336,6 +338,7 @@ cc_library( "//tensorflow/core:lib_internal", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/container:inlined_vector", "@com_google_absl//absl/memory", "@com_google_absl//absl/strings", @@ -490,6 +493,7 @@ cc_library( "//tensorflow/compiler/xla:util", "//tensorflow/core:lib", "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/memory", "@com_google_absl//absl/strings", "@com_google_absl//absl/strings:str_format", @@ -781,6 +785,7 @@ cc_library( "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/core:lib", "//tensorflow/core:stream_executor_no_cuda", + "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/memory", "@com_google_absl//absl/strings", "@com_google_absl//absl/strings:str_format", @@ -959,6 +964,7 @@ cc_library( "//tensorflow/compiler/xla:types", "//tensorflow/core:lib", "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/strings", ], ) @@ -995,6 +1001,7 @@ cc_library( "//tensorflow/compiler/xla:util", "//tensorflow/core:lib", "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/strings", "@com_google_absl//absl/strings:str_format", ], @@ -1043,6 +1050,7 @@ cc_library( "//tensorflow/core:lib", "//tensorflow/core:lib_internal", "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/memory", "@com_google_absl//absl/strings", "@com_google_absl//absl/strings:str_format", @@ -1136,6 +1144,7 @@ cc_library( "//tensorflow/compiler/xla:util", "//tensorflow/core:lib", "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/memory", ], ) @@ -1230,6 +1239,7 @@ cc_library( "//tensorflow/compiler/xla:util", "//tensorflow/core:lib", "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/memory", "@com_google_absl//absl/strings", "@com_google_absl//absl/types:span", @@ -1275,6 +1285,7 @@ cc_library( "//tensorflow/core:lib", "//tensorflow/core:lib_internal", "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/container:flat_hash_set", ], ) @@ -1348,6 +1359,7 @@ cc_library( "//tensorflow/compiler/xla/service:hlo_pass", "//tensorflow/core:lib", "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/strings", ], ) @@ -1660,6 +1672,7 @@ cc_library( "//tensorflow/compiler/xla:statusor", "//tensorflow/core:lib", "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/strings", "@com_google_absl//absl/types:optional", ], @@ -2064,6 +2077,7 @@ cc_library( ":logical_buffer", "//tensorflow/core:lib", "//tensorflow/core:lib_internal", + "@com_google_absl//absl/container:flat_hash_set", ], ) @@ -2099,6 +2113,7 @@ cc_library( "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/core:lib", + "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/memory", "@com_google_absl//absl/strings", "@com_google_absl//absl/types:span", @@ -2120,6 +2135,7 @@ cc_library( "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/core:lib", + "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/container:inlined_vector", "@com_google_absl//absl/memory", "@com_google_absl//absl/strings", @@ -2203,6 +2219,7 @@ cc_library( "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/core:lib", + "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/strings", ], ) @@ -2225,6 +2242,7 @@ cc_library( "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/core:lib", "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/strings", "@com_google_absl//absl/types:span", ], @@ -2286,6 +2304,7 @@ cc_library( "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/core:lib", "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/container:inlined_vector", "@com_google_absl//absl/memory", "@com_google_absl//absl/strings", @@ -2343,6 +2362,7 @@ cc_library( "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/core:lib", "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/memory", "@com_google_absl//absl/strings", "@com_google_absl//absl/strings:str_format", @@ -2370,6 +2390,7 @@ cc_library( "//tensorflow/compiler/xla:util", "//tensorflow/core:lib", "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/strings", ], ) @@ -2487,6 +2508,7 @@ cc_library( "//tensorflow/core:lib", "//tensorflow/core:lib_internal", "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/container:inlined_vector", "@com_google_absl//absl/strings", "@com_google_absl//absl/strings:str_format", @@ -2616,6 +2638,7 @@ cc_library( "//tensorflow/compiler/xla:util", "//tensorflow/core:lib", "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/memory", "@com_google_absl//absl/strings", "@com_google_absl//absl/strings:str_format", @@ -2655,6 +2678,7 @@ cc_library( "//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/core:lib", + "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/container:inlined_vector", ], ) @@ -2730,6 +2754,7 @@ cc_library( "//tensorflow/compiler/xla:util", "//tensorflow/core:lib", "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/memory", ], ) @@ -3300,6 +3325,7 @@ cc_library( "//tensorflow/core:lib", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/container:inlined_vector", ], ) @@ -3387,6 +3413,7 @@ cc_library( "//tensorflow/core:ptr_util", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/container:inlined_vector", "@com_google_absl//absl/strings", "@com_google_absl//absl/types:optional", diff --git a/tensorflow/compiler/xla/service/bfloat16_propagation.cc b/tensorflow/compiler/xla/service/bfloat16_propagation.cc index 58f78f8e24..002be9c970 100644 --- a/tensorflow/compiler/xla/service/bfloat16_propagation.cc +++ b/tensorflow/compiler/xla/service/bfloat16_propagation.cc @@ -15,6 +15,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/bfloat16_propagation.h" +#include "absl/container/flat_hash_set.h" #include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/map_util.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" @@ -81,7 +82,7 @@ void BFloat16Propagation::RevertIfFusionInternalBF16Changes( }; auto root = fusion->fused_instructions_computation()->root_instruction(); - tensorflow::gtl::FlatSet<const HloValue*> changed_root_buffers; + absl::flat_hash_set<const HloValue*> changed_root_buffers; auto root_changes_it = changes_to_bf16_.find(root); if (root_changes_it != changes_to_bf16_.end()) { @@ -500,7 +501,7 @@ void BFloat16Propagation::AdjustCalledComputationRoot(HloInstruction* hlo) { bool BFloat16Propagation::ResolveInconsistencyOfAliasingBuffersHelper( HloComputation* computation, - tensorflow::gtl::FlatSet<const HloComputation*>* visited_computations) { + absl::flat_hash_set<const HloComputation*>* visited_computations) { bool parameter_changed = false; auto insts = computation->MakeInstructionPostOrder(); // Do the adjustment on each instruction in the computation in reverse @@ -560,7 +561,7 @@ bool BFloat16Propagation::ResolveInconsistencyOfAliasingBuffersHelper( // another input parameter. A fixed point will be reached because the // parameters can only be changed from BF16 to F32, not the other way // around. - tensorflow::gtl::FlatSet<const HloComputation*> visited_in_while; + absl::flat_hash_set<const HloComputation*> visited_in_while; while (ResolveInconsistencyOfAliasingBuffersHelper(hlo->while_condition(), &visited_in_while) || ResolveInconsistencyOfAliasingBuffersHelper(hlo->while_body(), @@ -587,7 +588,7 @@ void BFloat16Propagation::ResolveInconsistencyOfAliasingBuffers( HloModule* module) { const auto& computations_topological_order = module->MakeComputationPostOrder(); - tensorflow::gtl::FlatSet<const HloComputation*> resolved; + absl::flat_hash_set<const HloComputation*> resolved; for (auto comp_it = computations_topological_order.rbegin(); comp_it != computations_topological_order.rend(); ++comp_it) { if (ContainsKey(resolved, *comp_it)) { diff --git a/tensorflow/compiler/xla/service/bfloat16_propagation.h b/tensorflow/compiler/xla/service/bfloat16_propagation.h index c74326f631..5fcaa15c83 100644 --- a/tensorflow/compiler/xla/service/bfloat16_propagation.h +++ b/tensorflow/compiler/xla/service/bfloat16_propagation.h @@ -22,6 +22,7 @@ limitations under the License. #include <vector> #include "absl/container/flat_hash_map.h" +#include "absl/container/flat_hash_set.h" #include "tensorflow/compiler/xla/service/bfloat16_support.h" #include "tensorflow/compiler/xla/service/hlo_dataflow_analysis.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" @@ -82,7 +83,7 @@ class BFloat16Propagation : public HloModulePass { // The set of instructions to consider using bfloat16, computed in the forward // pass. - tensorflow::gtl::FlatSet<const HloInstruction*> consider_using_bfloat16_; + absl::flat_hash_set<const HloInstruction*> consider_using_bfloat16_; // *************************** // Functions called and state produced by the backward pass (from root to @@ -111,12 +112,12 @@ class BFloat16Propagation : public HloModulePass { // The set of HloInstructions that have been visited in the // opportunity-finding pass. - tensorflow::gtl::FlatSet<const HloInstruction*> + absl::flat_hash_set<const HloInstruction*> instructions_visited_in_backward_pass_; // The set of HloComputations that have been visited in the // opportunity-finding pass. - tensorflow::gtl::FlatSet<const HloComputation*> + absl::flat_hash_set<const HloComputation*> computations_visited_in_backward_pass_; // *************************** @@ -132,7 +133,7 @@ class BFloat16Propagation : public HloModulePass { // point is reached. bool ResolveInconsistencyOfAliasingBuffersHelper( HloComputation* computation, - tensorflow::gtl::FlatSet<const HloComputation*>* visited_computations); + absl::flat_hash_set<const HloComputation*>* visited_computations); // Makes the parameters of called computations match how they are called by // the given HLO. @@ -183,7 +184,7 @@ class BFloat16Propagation : public HloModulePass { PrimitiveType target_type); // The set of F32 HLO values that must be kept in F32. - tensorflow::gtl::FlatSet<const HloValue*> values_that_must_be_kept_as_f32_; + absl::flat_hash_set<const HloValue*> values_that_must_be_kept_as_f32_; // Mapping from each HloComputation to the number of callers to it in the // module. Populated at the beginning of this pass. diff --git a/tensorflow/compiler/xla/service/buffer_assignment.cc b/tensorflow/compiler/xla/service/buffer_assignment.cc index 3efa0b1dad..2c2d1626c2 100644 --- a/tensorflow/compiler/xla/service/buffer_assignment.cc +++ b/tensorflow/compiler/xla/service/buffer_assignment.cc @@ -23,6 +23,7 @@ limitations under the License. #include <utility> #include "absl/container/flat_hash_map.h" +#include "absl/container/flat_hash_set.h" #include "absl/memory/memory.h" #include "absl/strings/str_cat.h" #include "absl/strings/str_format.h" @@ -43,9 +44,9 @@ namespace xla { namespace { using absl::flat_hash_map; +using absl::flat_hash_set; using absl::StrAppend; using absl::StrAppendFormat; -using ::tensorflow::gtl::FlatSet; using ::tensorflow::strings::HumanReadableNumBytes; template <typename T> @@ -129,8 +130,8 @@ Status GatherComputationsByAllocationType( // Sets for quickly checking membership. Computations are returned in vectors // for stable iteration. - FlatSet<const HloComputation*> thread_local_set; - FlatSet<const HloComputation*> global_set; + flat_hash_set<const HloComputation*> thread_local_set; + flat_hash_set<const HloComputation*> global_set; while (!worklist.empty()) { auto worklist_front = worklist.front(); @@ -445,7 +446,7 @@ bool BufferAssignment::SharesSliceAtIndex( bool BufferAssignment::HaveDisjointSlices(const HloInstruction* hlo_a, const HloInstruction* hlo_b) const { using SliceSet = - FlatSet<BufferAllocation::Slice, BufferAllocation::Slice::Hasher>; + flat_hash_set<BufferAllocation::Slice, BufferAllocation::Slice::Hasher>; // Gets the slices all of instr's subshapes. If any subshape doesn't have an // assigned slice, returns the empty set. auto collect_slices = [&](const HloInstruction* instr) -> SliceSet { @@ -815,9 +816,9 @@ bool BufferAssigner::MaybeAssignBuffer(BufferAllocation* allocation, Status BufferAssigner::AssignBuffersForComputation( const HloComputation* computation, bool is_thread_local, - const FlatSet<const LogicalBuffer*>& colocated_buffers, - const FlatSet<BufferAllocation::Index>& colocated_allocations, - flat_hash_map<const HloComputation*, FlatSet<const LogicalBuffer*>>* + const flat_hash_set<const LogicalBuffer*>& colocated_buffers, + const flat_hash_set<BufferAllocation::Index>& colocated_allocations, + flat_hash_map<const HloComputation*, flat_hash_set<const LogicalBuffer*>>* buffers_to_assign_sequentially, BufferAssignment* assignment) { // Buffers are sorted and assigned to BufferAllocations in decreasing order of @@ -853,8 +854,8 @@ Status BufferAssigner::AssignBuffersForComputation( // buffers_to_assign_sequentially map, even if we end up with an empty set // of buffers. This ensures we can correctly determine whether to run // whole-module heap simulation. - buffers_to_assign_sequentially->emplace(computation, - FlatSet<const LogicalBuffer*>()); + buffers_to_assign_sequentially->emplace( + computation, flat_hash_set<const LogicalBuffer*>()); } // Sort the LogicalBuffers first by size. We assign the larger LogicalBuffers @@ -1046,11 +1047,11 @@ Status BufferAssigner::AssignBuffersForComputation( return Status::OK(); } -flat_hash_map<LogicalBuffer::Color, FlatSet<const LogicalBuffer*>, +flat_hash_map<LogicalBuffer::Color, flat_hash_set<const LogicalBuffer*>, LogicalBuffer::Color::Hasher> BufferAssigner::SplitBuffersByColor( - const FlatSet<const LogicalBuffer*>& buffers) { - flat_hash_map<LogicalBuffer::Color, FlatSet<const LogicalBuffer*>, + const flat_hash_set<const LogicalBuffer*>& buffers) { + flat_hash_map<LogicalBuffer::Color, flat_hash_set<const LogicalBuffer*>, LogicalBuffer::Color::Hasher> color_map; for (auto buffer : buffers) { @@ -1060,7 +1061,8 @@ BufferAssigner::SplitBuffersByColor( } Status BufferAssigner::AssignBuffersWithSequentialOrdering( - const flat_hash_map<const HloComputation*, FlatSet<const LogicalBuffer*>>& + const flat_hash_map<const HloComputation*, + flat_hash_set<const LogicalBuffer*>>& buffers_to_assign_sequentially, bool run_whole_module_heap_simulation, BufferAssignment* assignment) { // Run the sequence of instructions through the heap simulator. The heuristic @@ -1086,10 +1088,11 @@ Status BufferAssigner::AssignBuffersWithSequentialOrdering( // only live for the duration of their calling instructions. VLOG(1) << "Running whole-module heap simulation"; HloSchedule schedule(&assignment->module()); - FlatSet<const LogicalBuffer*> all_buffers_to_assign; + flat_hash_set<const LogicalBuffer*> all_buffers_to_assign; for (const auto& pair : buffers_to_assign_sequentially) { const HloComputation* computation = pair.first; - const FlatSet<const LogicalBuffer*>& buffers_to_assign = pair.second; + const flat_hash_set<const LogicalBuffer*>& buffers_to_assign = + pair.second; const std::vector<const HloInstruction*>* instruction_sequence = hlo_ordering.SequentialOrder(*computation); CHECK(instruction_sequence != nullptr) << computation->name(); @@ -1123,7 +1126,8 @@ Status BufferAssigner::AssignBuffersWithSequentialOrdering( VLOG(1) << "Running per-computation heap simulation"; for (const auto& pair : buffers_to_assign_sequentially) { const HloComputation* computation = pair.first; - const FlatSet<const LogicalBuffer*>& buffers_to_assign = pair.second; + const flat_hash_set<const LogicalBuffer*>& buffers_to_assign = + pair.second; const std::vector<const HloInstruction*>* instruction_sequence = hlo_ordering.SequentialOrder(*computation); CHECK(instruction_sequence != nullptr) << computation->name(); @@ -1198,7 +1202,7 @@ std::vector<const LogicalBuffer*> ComputePeakMemoryLogicalBuffers( // Next gather the set of logical buffers live at the earliest point of // maximal live set size. - tensorflow::gtl::FlatSet<const LogicalBuffer*> live_buffers; + absl::flat_hash_set<const LogicalBuffer*> live_buffers; live_size = 0; for (const auto& event : heap_trace.events()) { const LogicalBuffer* buffer = id_to_buffer.at(event.buffer_id()); @@ -1588,8 +1592,8 @@ void BufferAssigner::BuildColocatedBufferSets( void BufferAssigner::AssignColocatedBufferSets( const std::vector<ColocatedBufferSet>& colocated_buffer_sets, BufferAssignment* assignment, - FlatSet<const LogicalBuffer*>* colocated_buffers, - FlatSet<BufferAllocation::Index>* colocated_allocations) { + flat_hash_set<const LogicalBuffer*>* colocated_buffers, + flat_hash_set<BufferAllocation::Index>* colocated_allocations) { for (const ColocatedBufferSet& colocated_buffer_set : colocated_buffer_sets) { BufferAllocation* allocation = nullptr; // Set 'entry_parameter_number' and 'entry_parameter_shape_idx' if entry @@ -1662,8 +1666,8 @@ StatusOr<std::unique_ptr<BufferAssignment>> BufferAssigner::CreateAssignment( // Once b/32491382 enables module-level liveness analysis, we may be able // to assign colocated buffers (or at least reuse their allocation for // buffers outside of the set) in AssignBuffersForComputation. - FlatSet<const LogicalBuffer*> colocated_buffers; - FlatSet<BufferAllocation::Index> colocated_allocations; + flat_hash_set<const LogicalBuffer*> colocated_buffers; + flat_hash_set<BufferAllocation::Index> colocated_allocations; std::vector<ColocatedBufferSet> colocated_buffer_sets; BuildColocatedBufferSets(module, assignment->liveness(), assignment->buffer_size_, &colocated_buffer_sets); @@ -1681,7 +1685,7 @@ StatusOr<std::unique_ptr<BufferAssignment>> BufferAssigner::CreateAssignment( // First assign buffers for global computatations. Temporary buffers for // sequential computations are collected in 'buffers_to_assign_sequentially'. - flat_hash_map<const HloComputation*, FlatSet<const LogicalBuffer*>> + flat_hash_map<const HloComputation*, flat_hash_set<const LogicalBuffer*>> buffers_to_assign_sequentially; for (auto* computation : global_computations) { TF_RETURN_IF_ERROR(AssignBuffersForComputation( diff --git a/tensorflow/compiler/xla/service/buffer_assignment.h b/tensorflow/compiler/xla/service/buffer_assignment.h index 9ba40617a3..899cd36e1f 100644 --- a/tensorflow/compiler/xla/service/buffer_assignment.h +++ b/tensorflow/compiler/xla/service/buffer_assignment.h @@ -23,6 +23,7 @@ limitations under the License. #include <vector> #include "absl/container/flat_hash_map.h" +#include "absl/container/flat_hash_set.h" #include "absl/types/span.h" #include "tensorflow/compiler/xla/service/buffer_liveness.h" #include "tensorflow/compiler/xla/service/heap_simulator.h" @@ -34,7 +35,6 @@ limitations under the License. #include "tensorflow/compiler/xla/service/tuple_points_to_analysis.h" #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/types.h" -#include "tensorflow/core/lib/gtl/flatset.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/macros.h" #include "tensorflow/core/platform/types.h" @@ -554,11 +554,10 @@ class BufferAssigner { // true. Status AssignBuffersForComputation( const HloComputation* computation, bool is_thread_local, - const tensorflow::gtl::FlatSet<const LogicalBuffer*>& colocated_buffers, - const tensorflow::gtl::FlatSet<BufferAllocation::Index>& - colocated_allocations, + const absl::flat_hash_set<const LogicalBuffer*>& colocated_buffers, + const absl::flat_hash_set<BufferAllocation::Index>& colocated_allocations, absl::flat_hash_map<const HloComputation*, - tensorflow::gtl::FlatSet<const LogicalBuffer*>>* + absl::flat_hash_set<const LogicalBuffer*>>* buffers_to_assign_sequentially, BufferAssignment* assignment); @@ -569,7 +568,7 @@ class BufferAssigner { // assuming all global computations are sequentially ordered. Status AssignBuffersWithSequentialOrdering( const absl::flat_hash_map<const HloComputation*, - tensorflow::gtl::FlatSet<const LogicalBuffer*>>& + absl::flat_hash_set<const LogicalBuffer*>>& buffers_to_assign_sequentially, bool run_whole_module_heap_simulation, BufferAssignment* assignment); @@ -589,7 +588,7 @@ class BufferAssigner { // alias. Explicitly handling these colocated buffers is necessary because // points-to analysis is computation level scope and does not recognize // aliasing across computations (b/32491382). - using ColocatedBufferSet = tensorflow::gtl::FlatSet<const LogicalBuffer*>; + using ColocatedBufferSet = absl::flat_hash_set<const LogicalBuffer*>; // Returns a vector of ColocatedBufferSet objects, where each // ColocatedBufferSet aggregates a set of related LogicalBuffers from 'module' @@ -604,8 +603,8 @@ class BufferAssigner { void AssignColocatedBufferSets( const std::vector<ColocatedBufferSet>& colocated_buffer_sets, BufferAssignment* assignment, - tensorflow::gtl::FlatSet<const LogicalBuffer*>* colocated_buffers, - tensorflow::gtl::FlatSet<BufferAllocation::Index>* colocated_allocations); + absl::flat_hash_set<const LogicalBuffer*>* colocated_buffers, + absl::flat_hash_set<BufferAllocation::Index>* colocated_allocations); // Adds the 'colocated_set' of buffers to 'colocated_buffer_sets', maintaining // the invariant that all sets in 'colocated_buffer_sets' are disjoint. @@ -624,10 +623,9 @@ class BufferAssigner { // Split a set of buffers into several sets, each of which contains buffers // colored with the same color. absl::flat_hash_map<LogicalBuffer::Color, - tensorflow::gtl::FlatSet<const LogicalBuffer*>, + absl::flat_hash_set<const LogicalBuffer*>, LogicalBuffer::Color::Hasher> - SplitBuffersByColor( - const tensorflow::gtl::FlatSet<const LogicalBuffer*>& buffers); + SplitBuffersByColor(const absl::flat_hash_set<const LogicalBuffer*>& buffers); // If true, buffer assignments assumes that input parameter buffers and output // buffers can be shared if their sizes match. diff --git a/tensorflow/compiler/xla/service/buffer_liveness.h b/tensorflow/compiler/xla/service/buffer_liveness.h index 2911bbcfbf..f939a426ea 100644 --- a/tensorflow/compiler/xla/service/buffer_liveness.h +++ b/tensorflow/compiler/xla/service/buffer_liveness.h @@ -20,6 +20,7 @@ limitations under the License. #include <string> #include <utility> +#include "absl/container/flat_hash_set.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_module.h" #include "tensorflow/compiler/xla/service/hlo_ordering.h" @@ -27,7 +28,6 @@ limitations under the License. #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/core/lib/core/status.h" -#include "tensorflow/core/lib/gtl/flatset.h" namespace xla { @@ -101,7 +101,7 @@ class BufferLiveness { // Set of LogicalBuffers which are aliased in the output of other // instructions. For example, a LogicalBuffer which is inserted into a tuple // is considered to be aliased and will be in this set. - tensorflow::gtl::FlatSet<const LogicalBuffer*> aliased_buffers_; + absl::flat_hash_set<const LogicalBuffer*> aliased_buffers_; // LogicalBuffers that may be live out of the entry computation. PointsToSet::BufferSet maybe_live_out_buffers_; diff --git a/tensorflow/compiler/xla/service/buffer_value_containers.h b/tensorflow/compiler/xla/service/buffer_value_containers.h index 305914fca8..cc46af5eee 100644 --- a/tensorflow/compiler/xla/service/buffer_value_containers.h +++ b/tensorflow/compiler/xla/service/buffer_value_containers.h @@ -16,10 +16,10 @@ limitations under the License. #ifndef TENSORFLOW_COMPILER_XLA_SERVICE_BUFFER_VALUE_CONTAINERS_H_ #define TENSORFLOW_COMPILER_XLA_SERVICE_BUFFER_VALUE_CONTAINERS_H_ +#include "absl/container/flat_hash_set.h" #include "tensorflow/compiler/xla/service/buffer_value.h" #include "tensorflow/compiler/xla/service/logical_buffer.h" #include "tensorflow/core/lib/gtl/compactptrset.h" -#include "tensorflow/core/lib/gtl/flatset.h" namespace xla { @@ -38,7 +38,7 @@ BufferValueCompactPointerSet ToBufferValueCompactPointerSet( return output; } -using BufferValueFlatSet = tensorflow::gtl::FlatSet<const BufferValue*>; +using BufferValueFlatSet = absl::flat_hash_set<const BufferValue*>; template <class LogicalBufferContainerT> BufferValueFlatSet ToBufferValueFlatSet( const LogicalBufferContainerT& logical_buffer_container) { diff --git a/tensorflow/compiler/xla/service/call_graph.cc b/tensorflow/compiler/xla/service/call_graph.cc index 23b2a32709..bdd5069632 100644 --- a/tensorflow/compiler/xla/service/call_graph.cc +++ b/tensorflow/compiler/xla/service/call_graph.cc @@ -17,6 +17,7 @@ limitations under the License. #include <queue> +#include "absl/container/flat_hash_set.h" #include "absl/memory/memory.h" #include "absl/strings/str_cat.h" #include "absl/strings/str_format.h" @@ -138,7 +139,7 @@ CallGraphNode& CallGraph::GetNode(const HloComputation* computation) { bool CallGraph::DominatesHelper( const HloComputation* a, const HloComputation* b, - tensorflow::gtl::FlatSet<const HloComputation*>* visited) const { + absl::flat_hash_set<const HloComputation*>* visited) const { if (a == b || ContainsKey(*visited, b)) { // The call graph is guaranteed to be acyclic so any previously visited node // we encounter was already determined to be dominated. @@ -163,7 +164,7 @@ bool CallGraph::DominatesHelper( bool CallGraph::Dominates(const HloComputation* a, const HloComputation* b) const { - tensorflow::gtl::FlatSet<const HloComputation*> visited; + absl::flat_hash_set<const HloComputation*> visited; return DominatesHelper(a, b, &visited); } @@ -277,7 +278,7 @@ std::unique_ptr<CallGraph> CallGraph::Build(const HloModule* module) { Status CallGraph::VisitNodesInternal( const VisitorFunction& visitor_func, const CallGraphNode& node, - tensorflow::gtl::FlatSet<const CallGraphNode*>* visited) const { + absl::flat_hash_set<const CallGraphNode*>* visited) const { auto pair = visited->insert(&node); if (!pair.second) { // Node was not inserted. Node has already been visited. @@ -294,7 +295,7 @@ Status CallGraph::VisitNodesInternal( Status CallGraph::VisitNodes(const VisitorFunction& visitor_func, bool visit_unreachable_nodes) const { - tensorflow::gtl::FlatSet<const CallGraphNode*> visited; + absl::flat_hash_set<const CallGraphNode*> visited; if (visit_unreachable_nodes) { // Traverse from all roots in the call graph. for (const CallGraphNode& node : nodes()) { diff --git a/tensorflow/compiler/xla/service/call_graph.h b/tensorflow/compiler/xla/service/call_graph.h index 0c2e9b99db..cb56f4789d 100644 --- a/tensorflow/compiler/xla/service/call_graph.h +++ b/tensorflow/compiler/xla/service/call_graph.h @@ -21,10 +21,10 @@ limitations under the License. #include <ostream> #include "absl/container/flat_hash_map.h" +#include "absl/container/flat_hash_set.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_module.h" -#include "tensorflow/core/lib/gtl/flatset.h" namespace xla { @@ -145,12 +145,12 @@ class CallGraphNode { // The computations called by this computation. The vector is used for a // stable ordering and the set enables fast membership testing. std::vector<HloComputation*> callees_; - tensorflow::gtl::FlatSet<HloComputation*> callee_set_; + absl::flat_hash_set<HloComputation*> callee_set_; // The computations which call this computation. The vector is used for a // stable ordering and the set enables fast membership testing. std::vector<HloComputation*> callers_; - tensorflow::gtl::FlatSet<HloComputation*> caller_set_; + absl::flat_hash_set<HloComputation*> caller_set_; // The call sites in this computation std::vector<CallSite> callsites_; @@ -250,14 +250,14 @@ class CallGraph { // 'visited'. Status VisitNodesInternal( const VisitorFunction& visitor_func, const CallGraphNode& node, - tensorflow::gtl::FlatSet<const CallGraphNode*>* visited) const; + absl::flat_hash_set<const CallGraphNode*>* visited) const; // Recursive helper for computing whether 'a' dominates 'b' in the call // graph. 'b_ancestor' is the currently visited node (which starts at 'b'), // and 'visited' is the set of computations which have been visited. bool DominatesHelper( const HloComputation* a, const HloComputation* b, - tensorflow::gtl::FlatSet<const HloComputation*>* visited) const; + absl::flat_hash_set<const HloComputation*>* visited) const; // The HLO module represented by this call graph. const HloModule* module_ = nullptr; diff --git a/tensorflow/compiler/xla/service/copy_insertion.cc b/tensorflow/compiler/xla/service/copy_insertion.cc index 7f78412924..f35324aa35 100644 --- a/tensorflow/compiler/xla/service/copy_insertion.cc +++ b/tensorflow/compiler/xla/service/copy_insertion.cc @@ -16,6 +16,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/copy_insertion.h" #include "absl/container/flat_hash_map.h" +#include "absl/container/flat_hash_set.h" #include "absl/strings/str_cat.h" #include "absl/strings/str_join.h" #include "tensorflow/compiler/xla/service/hlo_alias_analysis.h" @@ -32,7 +33,6 @@ limitations under the License. #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/util.h" -#include "tensorflow/core/lib/gtl/flatset.h" #include "tensorflow/core/platform/logging.h" namespace xla { @@ -904,7 +904,7 @@ class CopyRemover { // The heads of all the value lists. Each value list represents the HLO // values contained in a particular HLO buffer. The values in the list are // in dependency order. - tensorflow::gtl::FlatSet<const ValueNode*> value_lists_; + absl::flat_hash_set<const ValueNode*> value_lists_; // Copy removal requires fast access to the value list elements // corresponding to the source and destination values of the kCopy @@ -1009,7 +1009,7 @@ Status CopyInsertion::AddSpecialCaseCopies(const CallGraph& call_graph, HloInstruction* root = computation->root_instruction(); // Mark nondistinct/ambiguous indices. - tensorflow::gtl::FlatSet<const HloBuffer*> seen; + absl::flat_hash_set<const HloBuffer*> seen; ShapeUtil::ForEachSubshape( root->shape(), [&](const Shape& /*subshape*/, const ShapeIndex& index) { std::vector<const HloBuffer*> buffers_at_index = diff --git a/tensorflow/compiler/xla/service/cpu/BUILD b/tensorflow/compiler/xla/service/cpu/BUILD index 6a83909a3b..ae4c6e962d 100644 --- a/tensorflow/compiler/xla/service/cpu/BUILD +++ b/tensorflow/compiler/xla/service/cpu/BUILD @@ -291,6 +291,7 @@ cc_library( "//tensorflow/compiler/xla/service/llvm_ir:tuple_ops", "//tensorflow/core:lib", "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/strings", "@com_google_absl//absl/strings:str_format", "@com_google_absl//absl/types:span", diff --git a/tensorflow/compiler/xla/service/cpu/ir_emitter.cc b/tensorflow/compiler/xla/service/cpu/ir_emitter.cc index 953a75c35f..a70abb117a 100644 --- a/tensorflow/compiler/xla/service/cpu/ir_emitter.cc +++ b/tensorflow/compiler/xla/service/cpu/ir_emitter.cc @@ -25,6 +25,7 @@ limitations under the License. #include <vector> #include "absl/container/flat_hash_map.h" +#include "absl/container/flat_hash_set.h" #include "tensorflow/core/lib/math/math_util.h" #include "tensorflow/core/platform/logging.h" // IWYU pragma: no_include "llvm/IR/Intrinsics.gen.inc" @@ -68,7 +69,6 @@ limitations under the License. #include "tensorflow/compiler/xla/window_util.h" #include "tensorflow/core/lib/core/bits.h" #include "tensorflow/core/lib/core/errors.h" -#include "tensorflow/core/lib/gtl/flatset.h" namespace xla { @@ -1400,8 +1400,8 @@ static bool ReductionPreservesLayout(const HloInstruction& reduce) { // [0->0, 3->1]. absl::flat_hash_map<int64, int64> unreduced_dim_map; - gtl::FlatSet<int64> reduced_dims(reduce.dimensions().begin(), - reduce.dimensions().end()); + absl::flat_hash_set<int64> reduced_dims(reduce.dimensions().begin(), + reduce.dimensions().end()); const Shape& operand_shape = reduce.operand(0)->shape(); const Shape& result_shape = reduce.shape(); @@ -1977,7 +1977,7 @@ Status IrEmitter::HandleSlice(HloInstruction* slice) { // // * Implement the memcpy within the innermost loop. - gtl::FlatSet<int64> inner_dims; + absl::flat_hash_set<int64> inner_dims; for (int64 dim : LayoutUtil::MinorToMajor(layout)) { if (operand->shape().dimensions(dim) != slice->shape().dimensions(dim)) { break; diff --git a/tensorflow/compiler/xla/service/cpu/tests/cpu_noalias_test.cc b/tensorflow/compiler/xla/service/cpu/tests/cpu_noalias_test.cc index 7af51db55a..b35fd9dad8 100644 --- a/tensorflow/compiler/xla/service/cpu/tests/cpu_noalias_test.cc +++ b/tensorflow/compiler/xla/service/cpu/tests/cpu_noalias_test.cc @@ -121,7 +121,7 @@ TEST_F(CpuNoAliasTest, Concat) { CHECK: %read_concat2_array = load {{.*}} !alias.scope [[concat1_noalias]], !noalias [[concat1_scope]] CHECK-DAG: [[buf_size32:![0-9]+]] = !{!"buffer:{{.*}} size:32 CHECK-DAG: [[buf_size48:![0-9]+]] = !{!"buffer:{{.*}} size:48 - CHECK-DAG: [[param_x_noalias]] = !{[[buf_size32]], [[buf_size48]]} + CHECK-DAG: [[param_x_noalias]] = !{[[buf_size48]], [[buf_size32]]} CHECK-DAG: [[concat1_scope]] = !{[[buf_size32]]} CHECK-DAG: [[concat1_noalias]] = !{[[buf_size48]]} )"; diff --git a/tensorflow/compiler/xla/service/gpu/BUILD b/tensorflow/compiler/xla/service/gpu/BUILD index e65d3fa332..a838464cae 100644 --- a/tensorflow/compiler/xla/service/gpu/BUILD +++ b/tensorflow/compiler/xla/service/gpu/BUILD @@ -476,6 +476,7 @@ cc_library( "//tensorflow/compiler/xla/service:hlo", "//tensorflow/compiler/xla/service:instruction_fusion", "//tensorflow/compiler/xla/service:pattern_matcher", + "@com_google_absl//absl/container:flat_hash_set", ], ) @@ -508,6 +509,7 @@ cc_library( "//tensorflow/compiler/xla/service:multi_output_fusion", "//tensorflow/core:lib", "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/container:flat_hash_set", ], ) @@ -541,6 +543,7 @@ cc_library( "//tensorflow/compiler/xla/service:hlo_dataflow_analysis", "//tensorflow/compiler/xla/service:hlo_pass", "//tensorflow/core:lib", + "@com_google_absl//absl/container:flat_hash_set", ], ) diff --git a/tensorflow/compiler/xla/service/gpu/gpu_copy_insertion.cc b/tensorflow/compiler/xla/service/gpu/gpu_copy_insertion.cc index 79c74e7e8b..e2ab00ce41 100644 --- a/tensorflow/compiler/xla/service/gpu/gpu_copy_insertion.cc +++ b/tensorflow/compiler/xla/service/gpu/gpu_copy_insertion.cc @@ -19,6 +19,7 @@ limitations under the License. #include <set> #include <vector> +#include "absl/container/flat_hash_set.h" #include "tensorflow/compiler/xla/service/call_graph.h" #include "tensorflow/compiler/xla/service/copy_insertion.h" #include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h" @@ -27,7 +28,6 @@ limitations under the License. #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" #include "tensorflow/core/lib/core/status.h" -#include "tensorflow/core/lib/gtl/flatset.h" #include "tensorflow/core/platform/logging.h" namespace xla { diff --git a/tensorflow/compiler/xla/service/gpu/instruction_fusion.cc b/tensorflow/compiler/xla/service/gpu/instruction_fusion.cc index 4d5d8e99f8..b61f038739 100644 --- a/tensorflow/compiler/xla/service/gpu/instruction_fusion.cc +++ b/tensorflow/compiler/xla/service/gpu/instruction_fusion.cc @@ -15,6 +15,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/gpu/instruction_fusion.h" +#include "absl/container/flat_hash_set.h" #include "tensorflow/compiler/xla/service/gpu/gpu_fusible.h" #include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" @@ -125,8 +126,8 @@ bool IsIEEEFloatingPointScalarConstant(const HloInstruction* constant) { } // Compute the precise number of operands to the new fusion. - tensorflow::gtl::FlatSet<const HloInstruction*> operands( - a->operands().begin(), a->operands().end()); + absl::flat_hash_set<const HloInstruction*> operands(a->operands().begin(), + a->operands().end()); operands.insert(b->operands().begin(), b->operands().end()); // If there's an edge between `a` and `b`, don't count it: We're fusing that // producer -> consumer relationship. diff --git a/tensorflow/compiler/xla/service/gpu/multi_output_fusion.cc b/tensorflow/compiler/xla/service/gpu/multi_output_fusion.cc index c21f76f6eb..835924024b 100644 --- a/tensorflow/compiler/xla/service/gpu/multi_output_fusion.cc +++ b/tensorflow/compiler/xla/service/gpu/multi_output_fusion.cc @@ -24,6 +24,7 @@ limitations under the License. #include <utility> #include "absl/algorithm/container.h" +#include "absl/container/flat_hash_set.h" #include "tensorflow/compiler/xla/layout_util.h" #include "tensorflow/compiler/xla/service/gpu/gpu_fusible.h" #include "tensorflow/compiler/xla/service/gpu/instruction_fusion.h" @@ -31,7 +32,6 @@ limitations under the License. #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" #include "tensorflow/compiler/xla/shape_util.h" -#include "tensorflow/core/lib/gtl/flatset.h" #include "tensorflow/core/platform/types.h" namespace xla { @@ -101,7 +101,7 @@ bool GpuMultiOutputFusion::IsFusible(HloInstruction* instr) { int64 GpuMultiOutputFusion::GetProfit(HloInstruction* instr1, HloInstruction* instr2) { - tensorflow::gtl::FlatSet<HloInstruction*> in_list; + absl::flat_hash_set<HloInstruction*> in_list; for (auto instr : instr1->operands()) { if (!IsProfitableOperand(instr)) { continue; @@ -148,7 +148,7 @@ bool GpuMultiOutputFusion::DoProducerConsumerMultiOutputFusion() { bool changed = false; RecomputeReachability(); - tensorflow::gtl::FlatSet<HloInstruction*> to_fuse; + absl::flat_hash_set<HloInstruction*> to_fuse; // Keep a list of the instructions to fuse after making all the fusion // decisions. We first aggressively add instructions to potential_fusion_list, // then filter out instructions that will be no longer fusible because of diff --git a/tensorflow/compiler/xla/service/heap_simulator.cc b/tensorflow/compiler/xla/service/heap_simulator.cc index 147776c8c4..b343305554 100644 --- a/tensorflow/compiler/xla/service/heap_simulator.cc +++ b/tensorflow/compiler/xla/service/heap_simulator.cc @@ -19,6 +19,7 @@ limitations under the License. #include <vector> #include "absl/container/flat_hash_map.h" +#include "absl/container/flat_hash_set.h" #include "absl/memory/memory.h" #include "tensorflow/compiler/xla/map_util.h" #include "tensorflow/compiler/xla/util.h" @@ -26,7 +27,7 @@ limitations under the License. namespace xla { using absl::flat_hash_map; -using tensorflow::gtl::FlatSet; +using absl::flat_hash_set; /*static*/ StatusOr<int64> HeapSimulator::MinimumMemoryForModule( @@ -116,9 +117,9 @@ Status HeapSimulator::RunComputation( // 'used_buffers' is the reverse map - it tracks which buffers were used by an // instruction, so that we can remove the instructions from a buffer's live // set after they are visited. - flat_hash_map<const BufferValue*, FlatSet<const HloInstruction*>> + flat_hash_map<const BufferValue*, flat_hash_set<const HloInstruction*>> live_buffers; - flat_hash_map<const HloInstruction*, FlatSet<const BufferValue*>> + flat_hash_map<const HloInstruction*, flat_hash_set<const BufferValue*>> used_buffers; auto add_user_to_buffer = [this, &live_buffers, &used_buffers]( const HloInstruction* user, @@ -216,7 +217,7 @@ Status HeapSimulator::RunComputation( VLOG(4) << " Removing user " << instruction->name() << " from buffer " << operand_buffer->ToString(); auto it = live_buffers.find(operand_buffer); - FlatSet<const HloInstruction*>* live_set = &it->second; + flat_hash_set<const HloInstruction*>* live_set = &it->second; live_set->erase(instruction); if (live_set->empty()) { live_buffers.erase(it); @@ -238,7 +239,7 @@ Status HeapSimulator::RunComputation( // that we should assign. // Make sure each buffer get reused at most once. - FlatSet<const BufferValue*> reused_buffers; + flat_hash_set<const BufferValue*> reused_buffers; for (const BufferValue* buffer : buffers_defined_by_instruction) { if (IgnoreBuffer(buffer)) { continue; @@ -326,7 +327,7 @@ Status HeapSimulator::RunComputation( to_free.reserve(live_buffers.size()); for (const auto& buffer_pending : live_buffers) { const BufferValue* buffer = buffer_pending.first; - const FlatSet<const HloInstruction*>& pending = buffer_pending.second; + const flat_hash_set<const HloInstruction*>& pending = buffer_pending.second; CHECK_EQ(pending.size(), 1) << *buffer; CHECK(*pending.begin() == nullptr) << *buffer; to_free.push_back(buffer); diff --git a/tensorflow/compiler/xla/service/heap_simulator.h b/tensorflow/compiler/xla/service/heap_simulator.h index a5bb3f81f7..b0295a6163 100644 --- a/tensorflow/compiler/xla/service/heap_simulator.h +++ b/tensorflow/compiler/xla/service/heap_simulator.h @@ -22,6 +22,7 @@ limitations under the License. #include <vector> #include "absl/container/flat_hash_map.h" +#include "absl/container/flat_hash_set.h" #include "tensorflow/compiler/xla/service/buffer_value.h" #include "tensorflow/compiler/xla/service/buffer_value_containers.h" #include "tensorflow/compiler/xla/service/hlo.pb.h" @@ -31,7 +32,6 @@ limitations under the License. #include "tensorflow/compiler/xla/service/hlo_schedule.h" #include "tensorflow/compiler/xla/service/tuple_points_to_analysis.h" #include "tensorflow/compiler/xla/statusor.h" -#include "tensorflow/core/lib/gtl/flatset.h" namespace xla { @@ -197,8 +197,8 @@ class HeapSimulator { shared_buffers_; // Hold some sets for error-checking the sequence of Alloc and Free calls. - tensorflow::gtl::FlatSet<const BufferValue*> allocated_buffers_; - tensorflow::gtl::FlatSet<const BufferValue*> freed_buffers_; + absl::flat_hash_set<const BufferValue*> allocated_buffers_; + absl::flat_hash_set<const BufferValue*> freed_buffers_; // Debugging information filled in while the heap simulator runs. HeapSimulatorTrace debug_trace_; diff --git a/tensorflow/compiler/xla/service/hlo_alias_analysis.cc b/tensorflow/compiler/xla/service/hlo_alias_analysis.cc index b6e1f52cf5..c3da12e273 100644 --- a/tensorflow/compiler/xla/service/hlo_alias_analysis.cc +++ b/tensorflow/compiler/xla/service/hlo_alias_analysis.cc @@ -21,6 +21,7 @@ limitations under the License. #include <vector> #include "absl/container/flat_hash_map.h" +#include "absl/container/flat_hash_set.h" #include "absl/strings/str_cat.h" #include "absl/strings/str_join.h" #include "tensorflow/compiler/xla/map_util.h" @@ -120,7 +121,7 @@ class BufferValueMap { } // Return a set of all the values in the given buffer. - const tensorflow::gtl::FlatSet<const HloValue*>& GetValuesInBuffer( + const absl::flat_hash_set<const HloValue*>& GetValuesInBuffer( BufferNumber buffer_number) const { return buffers_.at(buffer_number); } @@ -143,7 +144,7 @@ class BufferValueMap { // Move the given value into the given buffer. void MoveValueToBuffer(const HloValue& value, BufferNumber buffer_number) { BufferNumber old_buffer_number = value_to_buffer_number_.at(&value); - tensorflow::gtl::FlatSet<const HloValue*>& old_value_set = + absl::flat_hash_set<const HloValue*>& old_value_set = buffers_.at(old_buffer_number); old_value_set.erase(&value); if (old_value_set.empty()) { @@ -291,7 +292,7 @@ class BufferValueMap { const HloDataflowAnalysis& dataflow_; // A map containing the set of values contained in each buffer. - absl::flat_hash_map<BufferNumber, tensorflow::gtl::FlatSet<const HloValue*>> + absl::flat_hash_map<BufferNumber, absl::flat_hash_set<const HloValue*>> buffers_; // A map indicating which buffer each value is contained in. @@ -351,7 +352,7 @@ bool HloAliasAnalysis::InstructionBuffersAreAmbiguous( bool HloAliasAnalysis::InstructionBuffersAreDistinct( const HloInstruction* instruction) const { - tensorflow::gtl::FlatSet<const HloBuffer*> buffers_seen; + absl::flat_hash_set<const HloBuffer*> buffers_seen; for (const auto& pair : dataflow_analysis_->GetInstructionValueSet(instruction)) { const HloValueSet& value_set = pair.second; diff --git a/tensorflow/compiler/xla/service/hlo_buffer.cc b/tensorflow/compiler/xla/service/hlo_buffer.cc index 6c11a073b7..9c3aa0e64d 100644 --- a/tensorflow/compiler/xla/service/hlo_buffer.cc +++ b/tensorflow/compiler/xla/service/hlo_buffer.cc @@ -20,6 +20,7 @@ limitations under the License. #include <utility> #include <vector> +#include "absl/container/flat_hash_set.h" #include "absl/strings/str_cat.h" #include "absl/strings/str_join.h" #include "tensorflow/compiler/xla/map_util.h" @@ -28,7 +29,6 @@ limitations under the License. #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/util.h" #include "tensorflow/core/lib/core/errors.h" -#include "tensorflow/core/lib/gtl/flatset.h" #include "tensorflow/core/platform/logging.h" namespace xla { diff --git a/tensorflow/compiler/xla/service/hlo_computation.cc b/tensorflow/compiler/xla/service/hlo_computation.cc index 257dd5876f..6ef67ab0a8 100644 --- a/tensorflow/compiler/xla/service/hlo_computation.cc +++ b/tensorflow/compiler/xla/service/hlo_computation.cc @@ -25,6 +25,7 @@ limitations under the License. #include "absl/algorithm/container.h" #include "absl/container/flat_hash_map.h" +#include "absl/container/flat_hash_set.h" #include "absl/memory/memory.h" #include "absl/strings/numbers.h" #include "absl/strings/str_cat.h" @@ -40,7 +41,6 @@ limitations under the License. #include "tensorflow/compiler/xla/util.h" #include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/lib/core/status.h" -#include "tensorflow/core/lib/gtl/flatset.h" #include "tensorflow/core/platform/logging.h" namespace xla { @@ -278,10 +278,9 @@ void HloComputation::set_root_instruction(HloInstruction* new_root_instruction, namespace { // Helper which builds a post order of the HLO call graph. -void ComputeComputationPostOrder( - HloComputation* computation, - tensorflow::gtl::FlatSet<HloComputation*>* visited, - std::vector<HloComputation*>* post_order) { +void ComputeComputationPostOrder(HloComputation* computation, + absl::flat_hash_set<HloComputation*>* visited, + std::vector<HloComputation*>* post_order) { if (visited->insert(computation).second) { for (auto* instruction : computation->instructions()) { for (HloComputation* called_computation : @@ -416,7 +415,7 @@ std::vector<HloInstruction*> HloComputation::MakeInstructionPostOrder() const { std::vector<HloComputation*> HloComputation::MakeEmbeddedComputationsList() const { - tensorflow::gtl::FlatSet<HloComputation*> visited; + absl::flat_hash_set<HloComputation*> visited; std::vector<HloComputation*> post_order; // To avoid special handling of this computation, cast away const of diff --git a/tensorflow/compiler/xla/service/hlo_computation.h b/tensorflow/compiler/xla/service/hlo_computation.h index af929ac009..d87ab4bda1 100644 --- a/tensorflow/compiler/xla/service/hlo_computation.h +++ b/tensorflow/compiler/xla/service/hlo_computation.h @@ -26,6 +26,7 @@ limitations under the License. #include <vector> #include "absl/container/flat_hash_map.h" +#include "absl/container/flat_hash_set.h" #include "absl/types/span.h" #include "tensorflow/compiler/xla/iterator_util.h" #include "tensorflow/compiler/xla/map_util.h" @@ -41,7 +42,6 @@ limitations under the License. #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/lib/core/status.h" -#include "tensorflow/core/lib/gtl/flatset.h" #include "tensorflow/core/platform/macros.h" #include "tensorflow/core/platform/types.h" diff --git a/tensorflow/compiler/xla/service/hlo_cse.cc b/tensorflow/compiler/xla/service/hlo_cse.cc index b59c9ba3ed..e602107cbe 100644 --- a/tensorflow/compiler/xla/service/hlo_cse.cc +++ b/tensorflow/compiler/xla/service/hlo_cse.cc @@ -23,6 +23,7 @@ limitations under the License. #include <utility> #include <vector> +#include "absl/container/flat_hash_set.h" #include "absl/container/inlined_vector.h" #include "tensorflow/compiler/xla/layout_util.h" #include "tensorflow/compiler/xla/literal.h" @@ -34,7 +35,6 @@ limitations under the License. #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/lib/core/errors.h" -#include "tensorflow/core/lib/gtl/flatset.h" #include "tensorflow/core/lib/hash/hash.h" namespace xla { @@ -137,8 +137,8 @@ StatusOr<bool> HloCSE::Run(HloModule* module) { // HLO instructions are grouped into equivalency classes by using the // cse_equal predicate defined above. This set holds a representative // instruction for each class. - tensorflow::gtl::FlatSet<HloInstruction*, decltype(&CseHash), - decltype(cse_equal)> + absl::flat_hash_set<HloInstruction*, decltype(&CseHash), + decltype(cse_equal)> representatives(/*N=*/computation->instruction_count() + 1, &CseHash, cse_equal); for (auto instruction : computation->MakeInstructionPostOrder()) { diff --git a/tensorflow/compiler/xla/service/hlo_dataflow_analysis.cc b/tensorflow/compiler/xla/service/hlo_dataflow_analysis.cc index 6a63681996..44cde4a3d2 100644 --- a/tensorflow/compiler/xla/service/hlo_dataflow_analysis.cc +++ b/tensorflow/compiler/xla/service/hlo_dataflow_analysis.cc @@ -19,6 +19,7 @@ limitations under the License. #include <queue> #include <vector> +#include "absl/container/flat_hash_set.h" #include "absl/container/inlined_vector.h" #include "absl/memory/memory.h" #include "absl/strings/str_cat.h" @@ -91,7 +92,7 @@ HloDataflowAnalysis::HloDataflowAnalysis( bool HloDataflowAnalysis::AreTransitiveUsesElementwiseOrTuple( const HloInstruction* inst) { - tensorflow::gtl::FlatSet<const HloInstruction*> visited; + absl::flat_hash_set<const HloInstruction*> visited; absl::InlinedVector<const HloInstruction*, 4> stack; stack.push_back(inst); while (!stack.empty()) { @@ -159,8 +160,8 @@ void HloDataflowAnalysis::MarkValueForDeletion(HloValue::Id value_id) { void HloDataflowAnalysis::DeleteMarkedValues() { #ifndef NDEBUG // Verify that no marked-for-deletion values are in any of the value sets. - tensorflow::gtl::FlatSet<HloValue::Id> id_set(value_ids_to_delete_.begin(), - value_ids_to_delete_.end()); + absl::flat_hash_set<HloValue::Id> id_set(value_ids_to_delete_.begin(), + value_ids_to_delete_.end()); for (const auto& pair : value_sets_) { const HloInstruction* instruction = pair.first; const InstructionValueSet& instruction_value_set = pair.second; @@ -673,7 +674,7 @@ bool HloDataflowAnalysis::UpdateInstructionValueSet( void HloDataflowAnalysis::Propagate() { std::queue<HloInstruction*> worklist; - tensorflow::gtl::FlatSet<HloInstruction*> workset; + absl::flat_hash_set<HloInstruction*> workset; auto add_to_worklist = [&worklist, &workset](HloInstruction* instruction) { if (workset.insert(instruction).second) { worklist.push(instruction); diff --git a/tensorflow/compiler/xla/service/hlo_domain_map.cc b/tensorflow/compiler/xla/service/hlo_domain_map.cc index 159c39d557..6ca1255ede 100644 --- a/tensorflow/compiler/xla/service/hlo_domain_map.cc +++ b/tensorflow/compiler/xla/service/hlo_domain_map.cc @@ -18,6 +18,7 @@ limitations under the License. #include <algorithm> #include "absl/container/flat_hash_map.h" +#include "absl/container/flat_hash_set.h" #include "absl/memory/memory.h" #include "tensorflow/compiler/xla/map_util.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" @@ -217,7 +218,7 @@ bool HloDomainMap::IsDomainInstruction(HloInstruction* instruction) const { /* static */ std::vector<HloInstruction*> HloDomainMap::MakeNonDomainInstructions( - const tensorflow::gtl::FlatSet<HloInstruction*>& instruction_set, + const absl::flat_hash_set<HloInstruction*>& instruction_set, const InstructionOrderMap& instructions_order) { std::vector<HloInstruction*> instructions; instructions.reserve(instruction_set.size()); diff --git a/tensorflow/compiler/xla/service/hlo_domain_map.h b/tensorflow/compiler/xla/service/hlo_domain_map.h index 8584bc021d..c8d581b746 100644 --- a/tensorflow/compiler/xla/service/hlo_domain_map.h +++ b/tensorflow/compiler/xla/service/hlo_domain_map.h @@ -20,13 +20,13 @@ limitations under the License. #include <vector> #include "absl/container/flat_hash_map.h" +#include "absl/container/flat_hash_set.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_domain_metadata.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_module.h" #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/core/lib/core/status.h" -#include "tensorflow/core/lib/gtl/flatset.h" namespace xla { @@ -110,7 +110,7 @@ class HloDomainMap { // Out of an instruction set, returns a vector of all the ones which are not // a kDomain kind. static std::vector<HloInstruction*> MakeNonDomainInstructions( - const tensorflow::gtl::FlatSet<HloInstruction*>& instruction_set, + const absl::flat_hash_set<HloInstruction*>& instruction_set, const InstructionOrderMap& instructions_order); // Populates domain_metadata_id_ that maps each HloInstruction to the unique diff --git a/tensorflow/compiler/xla/service/hlo_domain_metadata.h b/tensorflow/compiler/xla/service/hlo_domain_metadata.h index 302807f816..d3c83c15ae 100644 --- a/tensorflow/compiler/xla/service/hlo_domain_metadata.h +++ b/tensorflow/compiler/xla/service/hlo_domain_metadata.h @@ -20,11 +20,11 @@ limitations under the License. #include <string> #include <vector> +#include "absl/container/flat_hash_set.h" #include "absl/strings/string_view.h" #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/core/lib/core/status.h" -#include "tensorflow/core/lib/gtl/flatset.h" namespace xla { @@ -42,7 +42,7 @@ class DomainMetadata { // operand/user pathways, without crossing a kDomain instruction of a given // kind. The reach_set can contain kDomain instructions of other kinds, if // two domains of different kind intersect each other. - tensorflow::gtl::FlatSet<HloInstruction*> reach_set; + absl::flat_hash_set<HloInstruction*> reach_set; // The same instructions in reach_set, but purged from kDomain instructions // and ordered according to their computation graph post-order, i.e. @@ -55,8 +55,8 @@ class DomainMetadata { // whose dataflow enters the reach set (domain), while the exit_domains // contains the set of kDomain instructions whose dataflow exit the reach // set. - tensorflow::gtl::FlatSet<HloInstruction*> enter_domains; - tensorflow::gtl::FlatSet<HloInstruction*> exit_domains; + absl::flat_hash_set<HloInstruction*> enter_domains; + absl::flat_hash_set<HloInstruction*> exit_domains; }; virtual ~DomainMetadata() = default; diff --git a/tensorflow/compiler/xla/service/hlo_instruction.cc b/tensorflow/compiler/xla/service/hlo_instruction.cc index 5d5c9c7e58..0207f9ae3f 100644 --- a/tensorflow/compiler/xla/service/hlo_instruction.cc +++ b/tensorflow/compiler/xla/service/hlo_instruction.cc @@ -23,6 +23,7 @@ limitations under the License. #include "absl/algorithm/container.h" #include "absl/container/flat_hash_map.h" +#include "absl/container/flat_hash_set.h" #include "absl/container/inlined_vector.h" #include "absl/memory/memory.h" #include "absl/strings/ascii.h" @@ -44,7 +45,6 @@ limitations under the License. #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/util.h" #include "tensorflow/core/lib/core/errors.h" -#include "tensorflow/core/lib/gtl/flatset.h" #include "tensorflow/core/lib/gtl/map_util.h" #include "tensorflow/core/platform/human_readable_json.h" #include "tensorflow/core/platform/logging.h" @@ -1433,7 +1433,7 @@ int64 HloInstruction::operand_index(const HloInstruction* target) const { HloInstruction::InstructionVector HloInstruction::unique_operands() const { InstructionVector unique; - tensorflow::gtl::FlatSet<const HloInstruction*> seen; + absl::flat_hash_set<const HloInstruction*> seen; for (HloInstruction* operand : operands()) { if (seen.insert(operand).second) { unique.push_back(operand); diff --git a/tensorflow/compiler/xla/service/hlo_memory_scheduler.cc b/tensorflow/compiler/xla/service/hlo_memory_scheduler.cc index 1c2b2868fd..55314d0ae9 100644 --- a/tensorflow/compiler/xla/service/hlo_memory_scheduler.cc +++ b/tensorflow/compiler/xla/service/hlo_memory_scheduler.cc @@ -21,6 +21,7 @@ limitations under the License. #include <vector> #include "absl/container/flat_hash_map.h" +#include "absl/container/flat_hash_set.h" #include "tensorflow/compiler/xla/service/heap_simulator.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/tuple_points_to_analysis.h" @@ -111,7 +112,7 @@ class ListScheduler { // LogicalBuffer is in an operand of the instruction as indicated by // points-to analysis. for (auto* instruction : computation.instructions()) { - tensorflow::gtl::FlatSet<const LogicalBuffer*> instr_uses; + absl::flat_hash_set<const LogicalBuffer*> instr_uses; for (auto* operand : instruction->operands()) { points_to_analysis.GetPointsToSet(operand).ForEachElement( [&](const ShapeIndex& /*index*/, @@ -360,7 +361,7 @@ class ListScheduler { std::unordered_map<const LogicalBuffer*, int64> unscheduled_use_count_; // Set of instructions which have been scheduled. - tensorflow::gtl::FlatSet<const HloInstruction*> scheduled_instructions_; + absl::flat_hash_set<const HloInstruction*> scheduled_instructions_; }; int64 SumLogicalBufferSizes( @@ -418,7 +419,7 @@ StatusOr<HloInstructionSequence> DFSMemoryScheduler( points_to_analysis.GetBuffersDefinedByInstruction(hlo), size_function); total_sizes[hlo] = logical_buffer_size; cumulative_total_size += logical_buffer_size; - tensorflow::gtl::FlatSet<const HloInstruction*> unique_operands( + absl::flat_hash_set<const HloInstruction*> unique_operands( hlo->operands().begin(), hlo->operands().end()); for (const HloInstruction* operand : unique_operands) { extra_users[hlo] += extra_users[operand]; diff --git a/tensorflow/compiler/xla/service/hlo_module.cc b/tensorflow/compiler/xla/service/hlo_module.cc index 9359e9a8be..7527e35c95 100644 --- a/tensorflow/compiler/xla/service/hlo_module.cc +++ b/tensorflow/compiler/xla/service/hlo_module.cc @@ -24,6 +24,7 @@ limitations under the License. #include "absl/algorithm/container.h" #include "absl/container/flat_hash_map.h" +#include "absl/container/flat_hash_set.h" #include "absl/memory/memory.h" #include "absl/strings/str_cat.h" #include "tensorflow/compiler/xla/map_util.h" @@ -328,10 +329,10 @@ StatusOr<std::unique_ptr<HloModule>> HloModule::CreateFromProto( // Because we didn't uniquify the names or the ids, double-check that the // instruction and computation names and ids are unique from the proto. - tensorflow::gtl::FlatSet<string> computation_names; - tensorflow::gtl::FlatSet<string> instruction_names; - tensorflow::gtl::FlatSet<int> computation_ids; - tensorflow::gtl::FlatSet<int> instruction_ids; + absl::flat_hash_set<string> computation_names; + absl::flat_hash_set<string> instruction_names; + absl::flat_hash_set<int> computation_ids; + absl::flat_hash_set<int> instruction_ids; for (HloComputation* computation : module->computations()) { TF_RET_CHECK(!ContainsKey(computation_names, computation->name())) << "Computation name is not unique: " << computation->name(); diff --git a/tensorflow/compiler/xla/service/hlo_module_group_util.cc b/tensorflow/compiler/xla/service/hlo_module_group_util.cc index d83ee71490..fddeb5f0a2 100644 --- a/tensorflow/compiler/xla/service/hlo_module_group_util.cc +++ b/tensorflow/compiler/xla/service/hlo_module_group_util.cc @@ -22,6 +22,7 @@ limitations under the License. #include <string> #include <utility> +#include "absl/container/flat_hash_set.h" #include "absl/memory/memory.h" #include "absl/strings/str_cat.h" #include "tensorflow/compiler/xla/service/hlo_casting_utils.h" @@ -32,7 +33,6 @@ limitations under the License. #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/util.h" #include "tensorflow/core/lib/core/errors.h" -#include "tensorflow/core/lib/gtl/flatset.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/types.h" @@ -42,7 +42,7 @@ std::vector<HloInstruction*> HloModuleGroupUtil::GlobalPredecessors( HloInstruction* instruction) { std::vector<HloInstruction*> predecessors; // Use a vector to avoid non-determinism. - tensorflow::gtl::FlatSet<HloInstruction*> unique; + absl::flat_hash_set<HloInstruction*> unique; // Adds to the unique predecessors list; if the predecessors is a companion // instruction, also add companion instructions; if the predecessors is a @@ -119,7 +119,7 @@ std::vector<HloInstruction*> HloModuleGroupUtil::GlobalSuccessors( HloInstruction* instruction) { std::vector<HloInstruction*> successors; // Use a vector to avoid non-determinism. - tensorflow::gtl::FlatSet<HloInstruction*> unique; + absl::flat_hash_set<HloInstruction*> unique; // Adds to the unique successors list; if the successor is a companion // instruction, also add companion instructions; if the successor is a diff --git a/tensorflow/compiler/xla/service/hlo_pass_pipeline.cc b/tensorflow/compiler/xla/service/hlo_pass_pipeline.cc index 59fd01cb58..5e004ce78a 100644 --- a/tensorflow/compiler/xla/service/hlo_pass_pipeline.cc +++ b/tensorflow/compiler/xla/service/hlo_pass_pipeline.cc @@ -18,6 +18,7 @@ limitations under the License. #include <functional> #include "absl/container/flat_hash_map.h" +#include "absl/container/flat_hash_set.h" #include "absl/strings/str_format.h" #include "absl/strings/str_join.h" #include "tensorflow/compiler/xla/service/hlo_graph_dumper.h" @@ -25,7 +26,6 @@ limitations under the License. #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/util.h" -#include "tensorflow/core/lib/gtl/flatset.h" #include "tensorflow/core/platform/logging.h" namespace xla { @@ -75,8 +75,8 @@ StatusOr<bool> HloPassPipeline::RunPassesInternal( std::vector<HloPassInterface*> HloPassPipeline::GetEnabledPasses( const DebugOptions& debug_options) { auto repeated_field = debug_options.xla_disable_hlo_passes(); - tensorflow::gtl::FlatSet<string> disabled_pass_names(repeated_field.begin(), - repeated_field.end()); + absl::flat_hash_set<string> disabled_pass_names(repeated_field.begin(), + repeated_field.end()); if (!disabled_pass_names.empty()) { VLOG(1) << "Passes disabled by --xla_disable_hlo_passes: " << absl::StrJoin(disabled_pass_names, ", "); diff --git a/tensorflow/compiler/xla/service/hlo_rematerialization.cc b/tensorflow/compiler/xla/service/hlo_rematerialization.cc index abdd9a9212..5ac43808ee 100644 --- a/tensorflow/compiler/xla/service/hlo_rematerialization.cc +++ b/tensorflow/compiler/xla/service/hlo_rematerialization.cc @@ -21,6 +21,7 @@ limitations under the License. #include <string> #include "absl/container/flat_hash_map.h" +#include "absl/container/flat_hash_set.h" #include "absl/container/inlined_vector.h" #include "absl/strings/str_cat.h" #include "absl/strings/str_format.h" @@ -981,7 +982,7 @@ StatusOr<bool> HloRematerialization::RematerializeComputation( // rematerialization is essentially a move). If the next rematerialization of // the instruction is also a move then the rematerialization is added to the // blacklist. - tensorflow::gtl::FlatSet<const HloInstruction*> remat_move_instructions; + absl::flat_hash_set<const HloInstruction*> remat_move_instructions; // The map from instructions to their rematerializable status. absl::flat_hash_map<const HloInstruction*, bool> remat_able; diff --git a/tensorflow/compiler/xla/service/hlo_rematerialization.h b/tensorflow/compiler/xla/service/hlo_rematerialization.h index 5a02e3a8bb..70d83c04f0 100644 --- a/tensorflow/compiler/xla/service/hlo_rematerialization.h +++ b/tensorflow/compiler/xla/service/hlo_rematerialization.h @@ -16,6 +16,7 @@ #define TENSORFLOW_COMPILER_XLA_SERVICE_HLO_REMATERIALIZATION_H_ #include "absl/container/flat_hash_map.h" +#include "absl/container/flat_hash_set.h" #include "tensorflow/compiler/xla/service/buffer_liveness.h" #include "tensorflow/compiler/xla/service/call_graph.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" @@ -122,7 +123,7 @@ class HloRematerialization : public HloModulePass { // Set of computations which have had rematerialization // applied. Rematerialization is only applied once per computation. - tensorflow::gtl::FlatSet<const HloComputation*> rematerialized_computations_; + absl::flat_hash_set<const HloComputation*> rematerialized_computations_; // Count of the total instructions rematerialized. int64 instructions_rematerialized_ = 0; diff --git a/tensorflow/compiler/xla/service/hlo_schedule.cc b/tensorflow/compiler/xla/service/hlo_schedule.cc index 7c5c98f04e..9972eb2077 100644 --- a/tensorflow/compiler/xla/service/hlo_schedule.cc +++ b/tensorflow/compiler/xla/service/hlo_schedule.cc @@ -19,6 +19,7 @@ limitations under the License. #include <vector> #include "absl/container/flat_hash_map.h" +#include "absl/container/flat_hash_set.h" #include "absl/strings/str_format.h" #include "absl/strings/str_join.h" #include "tensorflow/compiler/xla/map_util.h" @@ -119,7 +120,7 @@ Status HloSchedule::UpdateComputationSchedule( } // Set of all HloInstructions in the schedule. - tensorflow::gtl::FlatSet<int> ids_in_schedule; + absl::flat_hash_set<int> ids_in_schedule; for (int id : sequences_.at(computation->unique_id()).ids()) { InsertOrDie(&ids_in_schedule, id); } @@ -210,7 +211,7 @@ Status HloSchedule::Update() { if (sequences_.size() > nonfusion_computations.size()) { // Schedule contains some computations which have been removed from the // HloModule. Remove them from the schedule as well. - tensorflow::gtl::FlatSet<int64> nonfusion_computations_ids; + absl::flat_hash_set<int64> nonfusion_computations_ids; for (const HloComputation* computation : nonfusion_computations) { nonfusion_computations_ids.insert(computation->unique_id()); } diff --git a/tensorflow/compiler/xla/service/hlo_value.cc b/tensorflow/compiler/xla/service/hlo_value.cc index 8549487702..59594ab2f0 100644 --- a/tensorflow/compiler/xla/service/hlo_value.cc +++ b/tensorflow/compiler/xla/service/hlo_value.cc @@ -18,6 +18,7 @@ limitations under the License. #include <algorithm> #include <utility> +#include "absl/container/flat_hash_set.h" #include "absl/memory/memory.h" #include "absl/strings/str_cat.h" #include "absl/strings/str_join.h" @@ -31,7 +32,6 @@ limitations under the License. #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/util.h" #include "tensorflow/core/lib/core/errors.h" -#include "tensorflow/core/lib/gtl/flatset.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/types.h" @@ -167,7 +167,7 @@ void HloValue::SetPositionsAndComputeUses( positions_.insert(positions_.end(), positions.begin(), positions.end()); // Gather the computation roots at which this value appears. - tensorflow::gtl::FlatSet<HloInstruction*> root_positions; + absl::flat_hash_set<HloInstruction*> root_positions; for (const HloPosition& position : positions_) { if (position.instruction == position.instruction->parent()->root_instruction()) { diff --git a/tensorflow/compiler/xla/service/indexed_array_analysis.cc b/tensorflow/compiler/xla/service/indexed_array_analysis.cc index 7ee789276d..1ebb331977 100644 --- a/tensorflow/compiler/xla/service/indexed_array_analysis.cc +++ b/tensorflow/compiler/xla/service/indexed_array_analysis.cc @@ -17,6 +17,7 @@ limitations under the License. #include "absl/algorithm/container.h" #include "absl/container/flat_hash_map.h" +#include "absl/container/flat_hash_set.h" #include "absl/container/inlined_vector.h" #include "absl/strings/str_cat.h" #include "absl/strings/str_join.h" @@ -24,7 +25,6 @@ limitations under the License. #include "tensorflow/compiler/xla/map_util.h" #include "tensorflow/compiler/xla/service/hlo_evaluator.h" #include "tensorflow/compiler/xla/util.h" -#include "tensorflow/core/lib/gtl/flatset.h" namespace xla { namespace gtl = ::tensorflow::gtl; diff --git a/tensorflow/compiler/xla/service/layout_assignment.h b/tensorflow/compiler/xla/service/layout_assignment.h index 1591256fad..15f0adcaaf 100644 --- a/tensorflow/compiler/xla/service/layout_assignment.h +++ b/tensorflow/compiler/xla/service/layout_assignment.h @@ -26,6 +26,7 @@ limitations under the License. #include <vector> #include "absl/container/flat_hash_map.h" +#include "absl/container/flat_hash_set.h" #include "tensorflow/compiler/xla/service/computation_layout.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" @@ -39,7 +40,6 @@ limitations under the License. #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/lib/core/status.h" -#include "tensorflow/core/lib/gtl/flatset.h" #include "tensorflow/core/platform/types.h" namespace xla { @@ -504,7 +504,7 @@ class LayoutAssignment : public HloModulePass { // Every copy added to the module by the layout assignment pass is registered // here. - tensorflow::gtl::FlatSet<HloInstruction*> added_copies_; + absl::flat_hash_set<HloInstruction*> added_copies_; // The pointer to the channel layout constraints passed in with the // constructor. If not nullptr, this is an input/output argument. @@ -521,8 +521,7 @@ class LayoutAssignment : public HloModulePass { // The set of HLO instructions which lacked any layout constraint, thus // receiving propagated default layouts. - tensorflow::gtl::FlatSet<const HloInstruction*> - unconstrained_layout_instructions_; + absl::flat_hash_set<const HloInstruction*> unconstrained_layout_instructions_; }; } // namespace xla diff --git a/tensorflow/compiler/xla/service/llvm_ir/BUILD b/tensorflow/compiler/xla/service/llvm_ir/BUILD index 3934d2e493..6223a34b12 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/BUILD +++ b/tensorflow/compiler/xla/service/llvm_ir/BUILD @@ -39,6 +39,7 @@ cc_library( "//tensorflow/compiler/xla/service:logical_buffer", "//tensorflow/core:lib", "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/strings", "@llvm//:core", ], diff --git a/tensorflow/compiler/xla/service/llvm_ir/alias_analysis.cc b/tensorflow/compiler/xla/service/llvm_ir/alias_analysis.cc index e5370eca56..643ecd0fba 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/alias_analysis.cc +++ b/tensorflow/compiler/xla/service/llvm_ir/alias_analysis.cc @@ -15,7 +15,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/llvm_ir/alias_analysis.h" -#include <unordered_set> +#include <map> #include "llvm/IR/MDBuilder.h" #include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h" @@ -164,9 +164,7 @@ llvm::MDNode* AliasAnalysis::GetNoaliasMetadataForBuffer( add_buffers_to_worklist(operand); } - tensorflow::gtl::FlatSet<BufferAllocation::Slice, - BufferAllocation::Slice::Hasher> - buffers; + std::set<BufferAllocation::Slice> buffers; for (const LogicalBuffer* buffer : worklist) { // Skip buffers which cannot be added to the noalias set. if (!assignment.HasAllocation(*buffer) || diff --git a/tensorflow/compiler/xla/service/llvm_ir/alias_analysis.h b/tensorflow/compiler/xla/service/llvm_ir/alias_analysis.h index 88cde2d3d9..2b46b3c396 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/alias_analysis.h +++ b/tensorflow/compiler/xla/service/llvm_ir/alias_analysis.h @@ -23,7 +23,6 @@ limitations under the License. #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/llvm_ir/ir_array.h" #include "tensorflow/compiler/xla/types.h" -#include "tensorflow/core/lib/gtl/flatset.h" namespace xla { namespace llvm_ir { diff --git a/tensorflow/compiler/xla/service/multi_output_fusion.cc b/tensorflow/compiler/xla/service/multi_output_fusion.cc index 95b1c20663..2ca527bc4c 100644 --- a/tensorflow/compiler/xla/service/multi_output_fusion.cc +++ b/tensorflow/compiler/xla/service/multi_output_fusion.cc @@ -15,10 +15,10 @@ limitations under the License. #include "tensorflow/compiler/xla/service/multi_output_fusion.h" +#include "absl/container/flat_hash_set.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" #include "tensorflow/compiler/xla/shape_util.h" -#include "tensorflow/core/lib/gtl/flatset.h" #include "tensorflow/core/platform/types.h" namespace xla { @@ -50,7 +50,7 @@ StatusOr<bool> MultiOutputFusion::Run(HloModule* module) { all_fusion_candidates_.push_back(instruction); std::vector<HloInstruction*> candidates; - tensorflow::gtl::FlatSet<HloInstruction*> candidates_set; + absl::flat_hash_set<HloInstruction*> candidates_set; VLOG(10) << "Looking at instruction: " << instruction->name(); for (auto operand : instruction->operands()) { // Filter out the non-interesting instructions -- they @@ -172,7 +172,7 @@ void MultiOutputFusion::Update(HloInstruction* instr1, HloInstruction* instr2) { // Update the fusible list for fusion. Variable new_fusibles keeps // track of the new or changed entries. std::vector<std::pair<HloInstruction*, int64>> new_fusibles; - tensorflow::gtl::FlatSet<HloInstruction*> in_list; + absl::flat_hash_set<HloInstruction*> in_list; auto it = fusion_node.fusibles.begin(); while (it != fusion_node.fusibles.end()) { HloInstruction* instr = it->first; diff --git a/tensorflow/compiler/xla/service/name_uniquer.h b/tensorflow/compiler/xla/service/name_uniquer.h index 1ac60f1cf4..8909d0f4fe 100644 --- a/tensorflow/compiler/xla/service/name_uniquer.h +++ b/tensorflow/compiler/xla/service/name_uniquer.h @@ -19,9 +19,9 @@ limitations under the License. #include <string> #include "absl/container/flat_hash_map.h" +#include "absl/container/flat_hash_set.h" #include "absl/strings/string_view.h" #include "tensorflow/compiler/xla/types.h" -#include "tensorflow/core/lib/gtl/flatset.h" #include "tensorflow/core/platform/macros.h" namespace xla { @@ -69,7 +69,7 @@ class NameUniquer { int64 next_ = 0; // Set of all the identifiers which has been used. - tensorflow::gtl::FlatSet<int64> used_; + absl::flat_hash_set<int64> used_; }; // The string to use to separate the prefix of the name from the uniquing diff --git a/tensorflow/compiler/xla/service/shape_inference.cc b/tensorflow/compiler/xla/service/shape_inference.cc index 6ccea9d2b5..e379911462 100644 --- a/tensorflow/compiler/xla/service/shape_inference.cc +++ b/tensorflow/compiler/xla/service/shape_inference.cc @@ -22,6 +22,7 @@ limitations under the License. #include <string> #include "absl/algorithm/container.h" +#include "absl/container/flat_hash_set.h" #include "absl/strings/str_cat.h" #include "absl/strings/str_format.h" #include "absl/strings/str_join.h" @@ -33,7 +34,6 @@ limitations under the License. #include "tensorflow/compiler/xla/window_util.h" #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/lib/core/errors.h" -#include "tensorflow/core/lib/gtl/flatset.h" #include "tensorflow/core/lib/math/math_util.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/protobuf.h" @@ -577,7 +577,7 @@ Status ValidateDotDimensionNumbers( // Check that dimension numbers are unique. auto dims_unique = [](absl::Span<const int64> contracting_dims, absl::Span<const int64> batch_dims) -> bool { - tensorflow::gtl::FlatSet<int64> dim_set; + absl::flat_hash_set<int64> dim_set; auto is_unique = [&dim_set](int64 i) -> bool { return dim_set.insert(i).second; }; diff --git a/tensorflow/compiler/xla/service/shaped_buffer.cc b/tensorflow/compiler/xla/service/shaped_buffer.cc index 921a984589..56952e3ada 100644 --- a/tensorflow/compiler/xla/service/shaped_buffer.cc +++ b/tensorflow/compiler/xla/service/shaped_buffer.cc @@ -18,6 +18,7 @@ limitations under the License. #include <string> #include <utility> +#include "absl/container/flat_hash_set.h" #include "absl/memory/memory.h" #include "absl/strings/str_cat.h" #include "absl/strings/str_format.h" @@ -26,7 +27,6 @@ limitations under the License. #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/util.h" -#include "tensorflow/core/lib/gtl/flatset.h" #include "tensorflow/core/platform/logging.h" namespace xla { @@ -147,7 +147,7 @@ void ScopedShapedBuffer::Deallocate() { // Deallocate all non-null buffers. A buffer may appear in more than one spot // in the shape (eg, a tuple with a repeated element) so keep track of what // has been deallocated. - tensorflow::gtl::FlatSet<void*> deallocated_ptrs; + absl::flat_hash_set<void*> deallocated_ptrs; for (auto& pair : buffers_) { se::DeviceMemoryBase& memory_base = pair.second; if (!memory_base.is_null() && diff --git a/tensorflow/compiler/xla/service/tuple_points_to_analysis.h b/tensorflow/compiler/xla/service/tuple_points_to_analysis.h index 78392d3bb2..64ad1dc80e 100644 --- a/tensorflow/compiler/xla/service/tuple_points_to_analysis.h +++ b/tensorflow/compiler/xla/service/tuple_points_to_analysis.h @@ -36,7 +36,6 @@ limitations under the License. #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/lib/core/status.h" #include "tensorflow/core/lib/gtl/compactptrset.h" -#include "tensorflow/core/lib/gtl/flatset.h" #include "tensorflow/core/platform/macros.h" #include "tensorflow/core/platform/types.h" diff --git a/tensorflow/compiler/xla/service/while_loop_invariant_code_motion.cc b/tensorflow/compiler/xla/service/while_loop_invariant_code_motion.cc index 2590473c77..9795b2830b 100644 --- a/tensorflow/compiler/xla/service/while_loop_invariant_code_motion.cc +++ b/tensorflow/compiler/xla/service/while_loop_invariant_code_motion.cc @@ -16,17 +16,17 @@ limitations under the License. #include "tensorflow/compiler/xla/service/while_loop_invariant_code_motion.h" #include "absl/algorithm/container.h" #include "absl/container/flat_hash_map.h" +#include "absl/container/flat_hash_set.h" #include "absl/container/inlined_vector.h" #include "tensorflow/compiler/xla/service/tuple_util.h" #include "tensorflow/compiler/xla/service/while_util.h" #include "tensorflow/compiler/xla/util.h" -#include "tensorflow/core/lib/gtl/flatset.h" namespace xla { using absl::flat_hash_map; +using absl::flat_hash_set; using absl::InlinedVector; -using tensorflow::gtl::FlatSet; // Copies `to_hoist` to the computation containing `while_instr`, hoisting its // operands as needed. All of its transitive operands are expected to be either @@ -35,7 +35,7 @@ using tensorflow::gtl::FlatSet; // them into `hoisted_instructions`. static void CreateLoopInvariantCopy( flat_hash_map<HloInstruction*, HloInstruction*>* hoisted_instructions, - FlatSet<HloInstruction*>* unhoisted_invariant_instructions, + flat_hash_set<HloInstruction*>* unhoisted_invariant_instructions, HloInstruction* while_instr, HloInstruction* to_hoist) { HloComputation* parent_of_while = while_instr->parent(); HloComputation* while_body = while_instr->while_body(); @@ -153,7 +153,7 @@ WhileLoopInvariantCodeMotion::TryHoistingInvariantInstructionsFromWhileBody( // unprofitable to be hoisted alone by NotWorthHoistingIndividually. When we // hoist an instruction in this set, we move it from // unhoisted_invariant_instructions to hoisted_instructions. - FlatSet<HloInstruction*> unhoisted_invariant_instructions; + flat_hash_set<HloInstruction*> unhoisted_invariant_instructions; // Invariant GTE's axiomatically satisfy the constraints for // unhoisted_invariant_instructions -- they can be legally hoisted, but there diff --git a/tensorflow/compiler/xla/service/while_loop_simplifier.cc b/tensorflow/compiler/xla/service/while_loop_simplifier.cc index 07de8492ba..630d71e5ca 100644 --- a/tensorflow/compiler/xla/service/while_loop_simplifier.cc +++ b/tensorflow/compiler/xla/service/while_loop_simplifier.cc @@ -15,6 +15,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/while_loop_simplifier.h" #include "absl/container/flat_hash_map.h" +#include "absl/container/flat_hash_set.h" #include "absl/strings/str_cat.h" #include "absl/strings/str_join.h" #include "absl/types/optional.h" @@ -114,7 +115,7 @@ static StatusOr<bool> TryRemoveDeadWhileParams(HloInstruction* while_op) { return false; } - tensorflow::gtl::FlatSet<int64> used_tuple_indices; + absl::flat_hash_set<int64> used_tuple_indices; for (HloComputation* comp : {while_body, while_cond}) { // The HLO verifier ensures that while_input's shape matches while_init's // shape, which we verified above is a tuple. diff --git a/tensorflow/compiler/xla/tests/BUILD b/tensorflow/compiler/xla/tests/BUILD index 06b6330321..8a0ae33042 100644 --- a/tensorflow/compiler/xla/tests/BUILD +++ b/tensorflow/compiler/xla/tests/BUILD @@ -2146,11 +2146,11 @@ xla_test( ":test_utils", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla/client:xla_builder", - "//tensorflow/compiler/xla/client:xla_computation", "//tensorflow/compiler/xla/service:hlo_parser", "//tensorflow/compiler/xla/tests:xla_internal_test_main", "//tensorflow/core:lib", "//tensorflow/core:test", + "@com_google_absl//absl/container:flat_hash_set", ], ) diff --git a/tensorflow/compiler/xla/tests/test_utils_test.cc b/tensorflow/compiler/xla/tests/test_utils_test.cc index 181e5cbe29..bc433eac8f 100644 --- a/tensorflow/compiler/xla/tests/test_utils_test.cc +++ b/tensorflow/compiler/xla/tests/test_utils_test.cc @@ -15,6 +15,7 @@ limitations under the License. #include "tensorflow/compiler/xla/tests/test_utils.h" +#include "absl/container/flat_hash_set.h" #include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/service/hlo_parser.h" #include "tensorflow/compiler/xla/shape_util.h" @@ -145,7 +146,7 @@ ENTRY %sort.148.1589 (parameter.0: f32[1048576], parameter.1: s32[1048576]) -> ( ASSERT_EQ(args.size(), 2); const Literal& key_arg = args[0]; - tensorflow::gtl::FlatSet<uint32> key_set; + absl::flat_hash_set<uint32> key_set; for (const float& value : key_arg.data<float>()) { EXPECT_TRUE(key_set.insert(tensorflow::bit_cast<uint32>(value)).second); } @@ -168,7 +169,7 @@ ENTRY %sort.148.1589 (parameter.0: s32[1048576], parameter.1: s32[1048576]) -> ( ASSERT_EQ(args.size(), 2); const Literal& key_arg = args[0]; - tensorflow::gtl::FlatSet<int32> key_set; + absl::flat_hash_set<int32> key_set; for (const int32& value : key_arg.data<int32>()) { EXPECT_TRUE(key_set.insert(tensorflow::bit_cast<uint32>(value)).second); } |