diff options
| Field | Value | Date |
|---|---|---|
| author | Benjamin Kramer <kramerb@google.com> | 2018-10-01 19:42:12 -0700 |
| committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-10-01 19:46:29 -0700 |
| commit | 991f06fd50fc73285ce415d57f720994c2b2e861 (patch) | |
| tree | 0b87402386aad22ec958f171bfd57f9c7c3e8571 /tensorflow/compiler/jit | |
| parent | beede8525be5386451bf0098992c37416d1864db (diff) | |
[XLA] Migrate from gtl::FlatSet to absl::flat_hash_set
PiperOrigin-RevId: 215324035
Diffstat (limited to 'tensorflow/compiler/jit')
7 files changed, 25 insertions, 21 deletions
diff --git a/tensorflow/compiler/jit/BUILD b/tensorflow/compiler/jit/BUILD index f20270931f..661b444a42 100644 --- a/tensorflow/compiler/jit/BUILD +++ b/tensorflow/compiler/jit/BUILD @@ -325,6 +325,7 @@ cc_library( "//tensorflow/core:lib", "//tensorflow/core:protos_all_cc", "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/memory", "@com_google_absl//absl/strings", "@com_google_absl//absl/types:optional", @@ -407,6 +408,7 @@ cc_library( "//tensorflow/core/kernels:bounds_check", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/memory", "@com_google_absl//absl/strings", ], diff --git a/tensorflow/compiler/jit/deadness_analysis.cc b/tensorflow/compiler/jit/deadness_analysis.cc index e63d4b7792..e0b9932d80 100644 --- a/tensorflow/compiler/jit/deadness_analysis.cc +++ b/tensorflow/compiler/jit/deadness_analysis.cc @@ -16,11 +16,11 @@ limitations under the License. #include "tensorflow/compiler/jit/deadness_analysis.h" #include "absl/algorithm/container.h" #include "absl/container/flat_hash_map.h" +#include "absl/container/flat_hash_set.h" #include "absl/strings/str_join.h" #include "tensorflow/compiler/jit/deadness_analysis_internal.h" #include "tensorflow/core/graph/algorithm.h" #include "tensorflow/core/graph/tensor_id.h" -#include "tensorflow/core/lib/gtl/flatset.h" #include "tensorflow/core/lib/hash/hash.h" // ALGORITHM OVERVIEW @@ -298,7 +298,7 @@ class SymbolPredicate : public Predicate { template <typename FunctionTy> /*static*/ void Predicate::Visit(Predicate* p, const FunctionTy& func) { - gtl::FlatSet<Predicate*> visited; + absl::flat_hash_set<Predicate*> visited; std::vector<Predicate*> stack; stack.push_back(p); @@ -467,7 +467,7 @@ Predicate* PredicateFactory::MakeAndOrImpl( is_and ? 
Predicate::Kind::kAnd : Predicate::Kind::kOr; Predicate::Kind other_pred_kind = is_and ? Predicate::Kind::kOr : Predicate::Kind::kAnd; - gtl::FlatSet<Predicate*> simplified_ops_set; + absl::flat_hash_set<Predicate*> simplified_ops_set; std::vector<Predicate*> simplified_ops; for (Predicate* op : operands) { // Simplify A&A => A and A|A => A. @@ -492,7 +492,7 @@ Predicate* PredicateFactory::MakeAndOrImpl( } // Simplify "A&~A=>False" and "A|~A=>True". - gtl::FlatSet<Predicate*> negated_ops; + absl::flat_hash_set<Predicate*> negated_ops; for (Predicate* op : simplified_ops) { if (op->kind() == Predicate::Kind::kNot) { negated_ops.insert(dynamic_cast<NotPredicate&>(*op).operand()); @@ -512,7 +512,7 @@ Predicate* PredicateFactory::MakeAndOrImpl( // // First find any predicates contained in all subops. std::vector<Predicate*> common_inner_operands; - gtl::FlatSet<Predicate*> common_inner_operands_set; + absl::flat_hash_set<Predicate*> common_inner_operands_set; for (Predicate* op : simplified_ops) { if (op->kind() != other_pred_kind) { common_inner_operands.clear(); diff --git a/tensorflow/compiler/jit/encapsulate_subgraphs_pass.cc b/tensorflow/compiler/jit/encapsulate_subgraphs_pass.cc index d165341f21..da27f837e8 100644 --- a/tensorflow/compiler/jit/encapsulate_subgraphs_pass.cc +++ b/tensorflow/compiler/jit/encapsulate_subgraphs_pass.cc @@ -22,6 +22,7 @@ limitations under the License. #include <unordered_map> #include <vector> +#include "absl/container/flat_hash_set.h" #include "absl/strings/match.h" #include "absl/strings/str_cat.h" #include "tensorflow/compiler/jit/graphcycles/graphcycles.h" @@ -44,7 +45,6 @@ limitations under the License. 
#include "tensorflow/core/graph/graph.h" #include "tensorflow/core/graph/graph_def_builder.h" #include "tensorflow/core/graph/tensor_id.h" -#include "tensorflow/core/lib/gtl/flatset.h" #include "tensorflow/core/lib/gtl/map_util.h" #include "tensorflow/core/lib/hash/hash.h" #include "tensorflow/core/public/session_options.h" @@ -78,7 +78,8 @@ void SortControlInputs(GraphDef* gdef) { namespace { bool AreAllParentsGuaranteedConst( - const Node& n, const gtl::FlatSet<const Node*>& runtime_const_nodes) { + const Node& n, + const absl::flat_hash_set<const Node*>& runtime_const_nodes) { if (n.type_string() == "GuaranteeConst") { // If the current node is itself a cast-to-const, no need // to look at the incoming edges. @@ -101,7 +102,7 @@ bool AreAllParentsGuaranteedConst( void MarkGuaranteedConstants( const Graph& graph, const std::vector<std::pair<const Node*, Node*>>& src_arg_pairs) { - gtl::FlatSet<const Node*> guaranteed_const_nodes; + absl::flat_hash_set<const Node*> guaranteed_const_nodes; std::vector<const Node*> srcs; srcs.reserve(src_arg_pairs.size()); for (const auto& src_arg : src_arg_pairs) { diff --git a/tensorflow/compiler/jit/encapsulate_xla_computations_pass.cc b/tensorflow/compiler/jit/encapsulate_xla_computations_pass.cc index 755c364c62..2ce6fa73fc 100644 --- a/tensorflow/compiler/jit/encapsulate_xla_computations_pass.cc +++ b/tensorflow/compiler/jit/encapsulate_xla_computations_pass.cc @@ -15,13 +15,13 @@ limitations under the License. 
#include "tensorflow/compiler/jit/encapsulate_xla_computations_pass.h" +#include "absl/container/flat_hash_set.h" #include "absl/memory/memory.h" #include "absl/strings/str_cat.h" #include "tensorflow/compiler/jit/encapsulate_subgraphs_pass.h" #include "tensorflow/compiler/tf2xla/dump_graph.h" #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/core/framework/node_def.pb.h" -#include "tensorflow/core/lib/gtl/flatset.h" #include "tensorflow/core/lib/hash/hash.h" #include "tensorflow/core/lib/strings/proto_serialization.h" #include "tensorflow/core/lib/strings/str_util.h" @@ -62,7 +62,7 @@ DataType EdgeType(const Edge* edge) { } // Adds the control inputs of `node` to `*deps`. -void AddControlInputs(const Node& node, gtl::FlatSet<Node*>* deps) { +void AddControlInputs(const Node& node, absl::flat_hash_set<Node*>* deps) { for (const Edge* edge : node.in_edges()) { if (edge->IsControlEdge()) { deps->insert(edge->src()); @@ -71,7 +71,7 @@ void AddControlInputs(const Node& node, gtl::FlatSet<Node*>* deps) { } // Adds the control outputs of `node` to `*deps`. -void AddControlOutputs(const Node& node, gtl::FlatSet<Node*>* deps) { +void AddControlOutputs(const Node& node, absl::flat_hash_set<Node*>* deps) { for (const Edge* edge : node.out_edges()) { if (edge->IsControlEdge()) { deps->insert(edge->dst()); @@ -246,7 +246,7 @@ Status RewriteSubgraph(const std::vector<OutputTensor>& arg_source_tensors, // Data and control inputs to the new XlaLaunch node. std::vector<std::pair<Node*, int>> data_inputs(num_inputs); - gtl::FlatSet<Node*> control_inputs; + absl::flat_hash_set<Node*> control_inputs; DataTypeVector arg_types(num_args); AddControlInputs(*launch, &control_inputs); @@ -266,7 +266,7 @@ Status RewriteSubgraph(const std::vector<OutputTensor>& arg_source_tensors, // Outputs. 
const int num_outputs = launch->output_types().size(); - gtl::FlatSet<Node*> control_outputs; + absl::flat_hash_set<Node*> control_outputs; std::vector<std::vector<std::pair<Node*, int>>> data_outputs(num_outputs); DataTypeVector output_types(num_outputs); diff --git a/tensorflow/compiler/jit/mark_for_compilation_pass.cc b/tensorflow/compiler/jit/mark_for_compilation_pass.cc index 133d982360..4f0c370e65 100644 --- a/tensorflow/compiler/jit/mark_for_compilation_pass.cc +++ b/tensorflow/compiler/jit/mark_for_compilation_pass.cc @@ -21,6 +21,7 @@ limitations under the License. #include <unordered_map> #include <unordered_set> +#include "absl/container/flat_hash_set.h" #include "tensorflow/compiler/jit/deadness_analysis.h" #include "tensorflow/compiler/jit/defs.h" #include "tensorflow/compiler/jit/graphcycles/graphcycles.h" @@ -42,7 +43,6 @@ limitations under the License. #include "tensorflow/core/graph/control_flow.h" #include "tensorflow/core/kernels/bounds_check.h" #include "tensorflow/core/lib/gtl/cleanup.h" -#include "tensorflow/core/lib/gtl/flatset.h" #include "tensorflow/core/lib/strings/stringprintf.h" #include "tensorflow/core/public/version.h" @@ -371,7 +371,7 @@ bool IsXlaFusable(const NodeDef& node) { Status FindCompilationCandidates( const Graph& graph, FunctionLibraryDefinition* flib_def, Env* env, const std::function<bool(const Node*, const DeviceType&)>& is_compilable_fn, - OrderedNodeSet* candidates, gtl::FlatSet<Node*>* isolated_nodes) { + OrderedNodeSet* candidates, absl::flat_hash_set<Node*>* isolated_nodes) { OptimizerOptions opts; std::unique_ptr<ProcessFunctionLibraryRuntime> pflr( new ProcessFunctionLibraryRuntime(nullptr, env, TF_GRAPH_DEF_VERSION, @@ -849,7 +849,7 @@ Status MarkForCompilationPass::RunImpl( Graph* graph = options.graph->get(); OrderedNodeSet compilation_candidates; - gtl::FlatSet<Node*> isolated_nodes; + absl::flat_hash_set<Node*> isolated_nodes; TF_RETURN_IF_ERROR(FindCompilationCandidates( *graph, options.flib_def, 
(options.session_options != nullptr) ? options.session_options->env diff --git a/tensorflow/compiler/jit/partially_decluster_pass.cc b/tensorflow/compiler/jit/partially_decluster_pass.cc index 10fc9e85d9..b1f9e9088f 100644 --- a/tensorflow/compiler/jit/partially_decluster_pass.cc +++ b/tensorflow/compiler/jit/partially_decluster_pass.cc @@ -15,17 +15,18 @@ limitations under the License. #include "tensorflow/compiler/jit/partially_decluster_pass.h" #include "absl/algorithm/container.h" +#include "absl/container/flat_hash_set.h" #include "absl/strings/str_cat.h" #include "tensorflow/compiler/jit/xla_cluster_util.h" #include "tensorflow/compiler/tf2xla/const_analysis.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" #include "tensorflow/core/framework/memory_types.h" #include "tensorflow/core/framework/node_def.pb.h" -#include "tensorflow/core/lib/gtl/flatset.h" namespace tensorflow { namespace { -Status FindNodesToDecluster(const Graph& graph, gtl::FlatSet<Node*>* result, +Status FindNodesToDecluster(const Graph& graph, + absl::flat_hash_set<Node*>* result, absl::Span<Node* const> post_order) { // Find nodes that have at least one user outside their cluster that expects // hostmem output. 
These nodes should be cloned to outside the cluster to @@ -171,7 +172,7 @@ Status PartiallyDeclusterToRemoveDeviceToHostCopies(Graph* graph) { GetPostOrder(*graph, &post_order, /*stable_comparator=*/NodeComparatorName(), /*edge_filter=*/NotBackedge); - gtl::FlatSet<Node*> nodes_to_partially_decluster; + absl::flat_hash_set<Node*> nodes_to_partially_decluster; TF_RETURN_IF_ERROR( FindNodesToDecluster(*graph, &nodes_to_partially_decluster, post_order)); diff --git a/tensorflow/compiler/jit/resource_operation_safety_analysis.cc b/tensorflow/compiler/jit/resource_operation_safety_analysis.cc index 657bb409db..e039d46ec8 100644 --- a/tensorflow/compiler/jit/resource_operation_safety_analysis.cc +++ b/tensorflow/compiler/jit/resource_operation_safety_analysis.cc @@ -82,6 +82,7 @@ limitations under the License. #include "tensorflow/compiler/jit/resource_operation_safety_analysis.h" +#include "absl/container/flat_hash_set.h" #include "absl/memory/memory.h" #include "absl/strings/str_join.h" #include "absl/types/optional.h" @@ -89,7 +90,6 @@ limitations under the License. #include "tensorflow/core/framework/node_def.pb.h" #include "tensorflow/core/graph/algorithm.h" #include "tensorflow/core/graph/tensor_id.h" -#include "tensorflow/core/lib/gtl/flatset.h" #include "tensorflow/core/lib/hash/hash.h" #include "tensorflow/core/util/ptr_util.h" @@ -176,7 +176,7 @@ string ResourceOpToString(const ResourceOp& resource_op) { // point. class ResourceOpSet { private: - using Impl = gtl::FlatSet<ResourceOp>; + using Impl = absl::flat_hash_set<ResourceOp>; public: ResourceOpSet() = default; |