aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler/jit
diff options
context:
space:
mode:
authorGravatar Benjamin Kramer <kramerb@google.com>2018-10-01 19:42:12 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-10-01 19:46:29 -0700
commit991f06fd50fc73285ce415d57f720994c2b2e861 (patch)
tree0b87402386aad22ec958f171bfd57f9c7c3e8571 /tensorflow/compiler/jit
parentbeede8525be5386451bf0098992c37416d1864db (diff)
[XLA] Migrate from gtl::FlatSet to absl::flat_hash_set
PiperOrigin-RevId: 215324035
Diffstat (limited to 'tensorflow/compiler/jit')
-rw-r--r--tensorflow/compiler/jit/BUILD2
-rw-r--r--tensorflow/compiler/jit/deadness_analysis.cc10
-rw-r--r--tensorflow/compiler/jit/encapsulate_subgraphs_pass.cc7
-rw-r--r--tensorflow/compiler/jit/encapsulate_xla_computations_pass.cc10
-rw-r--r--tensorflow/compiler/jit/mark_for_compilation_pass.cc6
-rw-r--r--tensorflow/compiler/jit/partially_decluster_pass.cc7
-rw-r--r--tensorflow/compiler/jit/resource_operation_safety_analysis.cc4
7 files changed, 25 insertions, 21 deletions
diff --git a/tensorflow/compiler/jit/BUILD b/tensorflow/compiler/jit/BUILD
index f20270931f..661b444a42 100644
--- a/tensorflow/compiler/jit/BUILD
+++ b/tensorflow/compiler/jit/BUILD
@@ -325,6 +325,7 @@ cc_library(
"//tensorflow/core:lib",
"//tensorflow/core:protos_all_cc",
"@com_google_absl//absl/container:flat_hash_map",
+ "@com_google_absl//absl/container:flat_hash_set",
"@com_google_absl//absl/memory",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/types:optional",
@@ -407,6 +408,7 @@ cc_library(
"//tensorflow/core/kernels:bounds_check",
"@com_google_absl//absl/algorithm:container",
"@com_google_absl//absl/container:flat_hash_map",
+ "@com_google_absl//absl/container:flat_hash_set",
"@com_google_absl//absl/memory",
"@com_google_absl//absl/strings",
],
diff --git a/tensorflow/compiler/jit/deadness_analysis.cc b/tensorflow/compiler/jit/deadness_analysis.cc
index e63d4b7792..e0b9932d80 100644
--- a/tensorflow/compiler/jit/deadness_analysis.cc
+++ b/tensorflow/compiler/jit/deadness_analysis.cc
@@ -16,11 +16,11 @@ limitations under the License.
#include "tensorflow/compiler/jit/deadness_analysis.h"
#include "absl/algorithm/container.h"
#include "absl/container/flat_hash_map.h"
+#include "absl/container/flat_hash_set.h"
#include "absl/strings/str_join.h"
#include "tensorflow/compiler/jit/deadness_analysis_internal.h"
#include "tensorflow/core/graph/algorithm.h"
#include "tensorflow/core/graph/tensor_id.h"
-#include "tensorflow/core/lib/gtl/flatset.h"
#include "tensorflow/core/lib/hash/hash.h"
// ALGORITHM OVERVIEW
@@ -298,7 +298,7 @@ class SymbolPredicate : public Predicate {
template <typename FunctionTy>
/*static*/ void Predicate::Visit(Predicate* p, const FunctionTy& func) {
- gtl::FlatSet<Predicate*> visited;
+ absl::flat_hash_set<Predicate*> visited;
std::vector<Predicate*> stack;
stack.push_back(p);
@@ -467,7 +467,7 @@ Predicate* PredicateFactory::MakeAndOrImpl(
is_and ? Predicate::Kind::kAnd : Predicate::Kind::kOr;
Predicate::Kind other_pred_kind =
is_and ? Predicate::Kind::kOr : Predicate::Kind::kAnd;
- gtl::FlatSet<Predicate*> simplified_ops_set;
+ absl::flat_hash_set<Predicate*> simplified_ops_set;
std::vector<Predicate*> simplified_ops;
for (Predicate* op : operands) {
// Simplify A&A => A and A|A => A.
@@ -492,7 +492,7 @@ Predicate* PredicateFactory::MakeAndOrImpl(
}
// Simplify "A&~A=>False" and "A|~A=>True".
- gtl::FlatSet<Predicate*> negated_ops;
+ absl::flat_hash_set<Predicate*> negated_ops;
for (Predicate* op : simplified_ops) {
if (op->kind() == Predicate::Kind::kNot) {
negated_ops.insert(dynamic_cast<NotPredicate&>(*op).operand());
@@ -512,7 +512,7 @@ Predicate* PredicateFactory::MakeAndOrImpl(
//
// First find any predicates contained in all subops.
std::vector<Predicate*> common_inner_operands;
- gtl::FlatSet<Predicate*> common_inner_operands_set;
+ absl::flat_hash_set<Predicate*> common_inner_operands_set;
for (Predicate* op : simplified_ops) {
if (op->kind() != other_pred_kind) {
common_inner_operands.clear();
diff --git a/tensorflow/compiler/jit/encapsulate_subgraphs_pass.cc b/tensorflow/compiler/jit/encapsulate_subgraphs_pass.cc
index d165341f21..da27f837e8 100644
--- a/tensorflow/compiler/jit/encapsulate_subgraphs_pass.cc
+++ b/tensorflow/compiler/jit/encapsulate_subgraphs_pass.cc
@@ -22,6 +22,7 @@ limitations under the License.
#include <unordered_map>
#include <vector>
+#include "absl/container/flat_hash_set.h"
#include "absl/strings/match.h"
#include "absl/strings/str_cat.h"
#include "tensorflow/compiler/jit/graphcycles/graphcycles.h"
@@ -44,7 +45,6 @@ limitations under the License.
#include "tensorflow/core/graph/graph.h"
#include "tensorflow/core/graph/graph_def_builder.h"
#include "tensorflow/core/graph/tensor_id.h"
-#include "tensorflow/core/lib/gtl/flatset.h"
#include "tensorflow/core/lib/gtl/map_util.h"
#include "tensorflow/core/lib/hash/hash.h"
#include "tensorflow/core/public/session_options.h"
@@ -78,7 +78,8 @@ void SortControlInputs(GraphDef* gdef) {
namespace {
bool AreAllParentsGuaranteedConst(
- const Node& n, const gtl::FlatSet<const Node*>& runtime_const_nodes) {
+ const Node& n,
+ const absl::flat_hash_set<const Node*>& runtime_const_nodes) {
if (n.type_string() == "GuaranteeConst") {
// If the current node is itself a cast-to-const, no need
// to look at the incoming edges.
@@ -101,7 +102,7 @@ bool AreAllParentsGuaranteedConst(
void MarkGuaranteedConstants(
const Graph& graph,
const std::vector<std::pair<const Node*, Node*>>& src_arg_pairs) {
- gtl::FlatSet<const Node*> guaranteed_const_nodes;
+ absl::flat_hash_set<const Node*> guaranteed_const_nodes;
std::vector<const Node*> srcs;
srcs.reserve(src_arg_pairs.size());
for (const auto& src_arg : src_arg_pairs) {
diff --git a/tensorflow/compiler/jit/encapsulate_xla_computations_pass.cc b/tensorflow/compiler/jit/encapsulate_xla_computations_pass.cc
index 755c364c62..2ce6fa73fc 100644
--- a/tensorflow/compiler/jit/encapsulate_xla_computations_pass.cc
+++ b/tensorflow/compiler/jit/encapsulate_xla_computations_pass.cc
@@ -15,13 +15,13 @@ limitations under the License.
#include "tensorflow/compiler/jit/encapsulate_xla_computations_pass.h"
+#include "absl/container/flat_hash_set.h"
#include "absl/memory/memory.h"
#include "absl/strings/str_cat.h"
#include "tensorflow/compiler/jit/encapsulate_subgraphs_pass.h"
#include "tensorflow/compiler/tf2xla/dump_graph.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/core/framework/node_def.pb.h"
-#include "tensorflow/core/lib/gtl/flatset.h"
#include "tensorflow/core/lib/hash/hash.h"
#include "tensorflow/core/lib/strings/proto_serialization.h"
#include "tensorflow/core/lib/strings/str_util.h"
@@ -62,7 +62,7 @@ DataType EdgeType(const Edge* edge) {
}
// Adds the control inputs of `node` to `*deps`.
-void AddControlInputs(const Node& node, gtl::FlatSet<Node*>* deps) {
+void AddControlInputs(const Node& node, absl::flat_hash_set<Node*>* deps) {
for (const Edge* edge : node.in_edges()) {
if (edge->IsControlEdge()) {
deps->insert(edge->src());
@@ -71,7 +71,7 @@ void AddControlInputs(const Node& node, gtl::FlatSet<Node*>* deps) {
}
// Adds the control outputs of `node` to `*deps`.
-void AddControlOutputs(const Node& node, gtl::FlatSet<Node*>* deps) {
+void AddControlOutputs(const Node& node, absl::flat_hash_set<Node*>* deps) {
for (const Edge* edge : node.out_edges()) {
if (edge->IsControlEdge()) {
deps->insert(edge->dst());
@@ -246,7 +246,7 @@ Status RewriteSubgraph(const std::vector<OutputTensor>& arg_source_tensors,
// Data and control inputs to the new XlaLaunch node.
std::vector<std::pair<Node*, int>> data_inputs(num_inputs);
- gtl::FlatSet<Node*> control_inputs;
+ absl::flat_hash_set<Node*> control_inputs;
DataTypeVector arg_types(num_args);
AddControlInputs(*launch, &control_inputs);
@@ -266,7 +266,7 @@ Status RewriteSubgraph(const std::vector<OutputTensor>& arg_source_tensors,
// Outputs.
const int num_outputs = launch->output_types().size();
- gtl::FlatSet<Node*> control_outputs;
+ absl::flat_hash_set<Node*> control_outputs;
std::vector<std::vector<std::pair<Node*, int>>> data_outputs(num_outputs);
DataTypeVector output_types(num_outputs);
diff --git a/tensorflow/compiler/jit/mark_for_compilation_pass.cc b/tensorflow/compiler/jit/mark_for_compilation_pass.cc
index 133d982360..4f0c370e65 100644
--- a/tensorflow/compiler/jit/mark_for_compilation_pass.cc
+++ b/tensorflow/compiler/jit/mark_for_compilation_pass.cc
@@ -21,6 +21,7 @@ limitations under the License.
#include <unordered_map>
#include <unordered_set>
+#include "absl/container/flat_hash_set.h"
#include "tensorflow/compiler/jit/deadness_analysis.h"
#include "tensorflow/compiler/jit/defs.h"
#include "tensorflow/compiler/jit/graphcycles/graphcycles.h"
@@ -42,7 +43,6 @@ limitations under the License.
#include "tensorflow/core/graph/control_flow.h"
#include "tensorflow/core/kernels/bounds_check.h"
#include "tensorflow/core/lib/gtl/cleanup.h"
-#include "tensorflow/core/lib/gtl/flatset.h"
#include "tensorflow/core/lib/strings/stringprintf.h"
#include "tensorflow/core/public/version.h"
@@ -371,7 +371,7 @@ bool IsXlaFusable(const NodeDef& node) {
Status FindCompilationCandidates(
const Graph& graph, FunctionLibraryDefinition* flib_def, Env* env,
const std::function<bool(const Node*, const DeviceType&)>& is_compilable_fn,
- OrderedNodeSet* candidates, gtl::FlatSet<Node*>* isolated_nodes) {
+ OrderedNodeSet* candidates, absl::flat_hash_set<Node*>* isolated_nodes) {
OptimizerOptions opts;
std::unique_ptr<ProcessFunctionLibraryRuntime> pflr(
new ProcessFunctionLibraryRuntime(nullptr, env, TF_GRAPH_DEF_VERSION,
@@ -849,7 +849,7 @@ Status MarkForCompilationPass::RunImpl(
Graph* graph = options.graph->get();
OrderedNodeSet compilation_candidates;
- gtl::FlatSet<Node*> isolated_nodes;
+ absl::flat_hash_set<Node*> isolated_nodes;
TF_RETURN_IF_ERROR(FindCompilationCandidates(
*graph, options.flib_def,
(options.session_options != nullptr) ? options.session_options->env
diff --git a/tensorflow/compiler/jit/partially_decluster_pass.cc b/tensorflow/compiler/jit/partially_decluster_pass.cc
index 10fc9e85d9..b1f9e9088f 100644
--- a/tensorflow/compiler/jit/partially_decluster_pass.cc
+++ b/tensorflow/compiler/jit/partially_decluster_pass.cc
@@ -15,17 +15,18 @@ limitations under the License.
#include "tensorflow/compiler/jit/partially_decluster_pass.h"
#include "absl/algorithm/container.h"
+#include "absl/container/flat_hash_set.h"
#include "absl/strings/str_cat.h"
#include "tensorflow/compiler/jit/xla_cluster_util.h"
#include "tensorflow/compiler/tf2xla/const_analysis.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
#include "tensorflow/core/framework/memory_types.h"
#include "tensorflow/core/framework/node_def.pb.h"
-#include "tensorflow/core/lib/gtl/flatset.h"
namespace tensorflow {
namespace {
-Status FindNodesToDecluster(const Graph& graph, gtl::FlatSet<Node*>* result,
+Status FindNodesToDecluster(const Graph& graph,
+ absl::flat_hash_set<Node*>* result,
absl::Span<Node* const> post_order) {
// Find nodes that have at least one user outside their cluster that expects
// hostmem output. These nodes should be cloned to outside the cluster to
@@ -171,7 +172,7 @@ Status PartiallyDeclusterToRemoveDeviceToHostCopies(Graph* graph) {
GetPostOrder(*graph, &post_order, /*stable_comparator=*/NodeComparatorName(),
/*edge_filter=*/NotBackedge);
- gtl::FlatSet<Node*> nodes_to_partially_decluster;
+ absl::flat_hash_set<Node*> nodes_to_partially_decluster;
TF_RETURN_IF_ERROR(
FindNodesToDecluster(*graph, &nodes_to_partially_decluster, post_order));
diff --git a/tensorflow/compiler/jit/resource_operation_safety_analysis.cc b/tensorflow/compiler/jit/resource_operation_safety_analysis.cc
index 657bb409db..e039d46ec8 100644
--- a/tensorflow/compiler/jit/resource_operation_safety_analysis.cc
+++ b/tensorflow/compiler/jit/resource_operation_safety_analysis.cc
@@ -82,6 +82,7 @@ limitations under the License.
#include "tensorflow/compiler/jit/resource_operation_safety_analysis.h"
+#include "absl/container/flat_hash_set.h"
#include "absl/memory/memory.h"
#include "absl/strings/str_join.h"
#include "absl/types/optional.h"
@@ -89,7 +90,6 @@ limitations under the License.
#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/graph/algorithm.h"
#include "tensorflow/core/graph/tensor_id.h"
-#include "tensorflow/core/lib/gtl/flatset.h"
#include "tensorflow/core/lib/hash/hash.h"
#include "tensorflow/core/util/ptr_util.h"
@@ -176,7 +176,7 @@ string ResourceOpToString(const ResourceOp& resource_op) {
// point.
class ResourceOpSet {
private:
- using Impl = gtl::FlatSet<ResourceOp>;
+ using Impl = absl::flat_hash_set<ResourceOp>;
public:
ResourceOpSet() = default;