aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler/jit
diff options
context:
space:
mode:
authorGravatar Benjamin Kramer <kramerb@google.com>2018-10-01 13:43:49 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-10-01 13:53:07 -0700
commit3039a4694e22674b502257ae34b0a5b614a631f3 (patch)
tree423fdfa7a2e7dd2740af97accfe848bc97b335d0 /tensorflow/compiler/jit
parent44acd839c57494860666c799afd24360f1df3bed (diff)
[XLA] Migrate from gtl::FlatMap to absl::flat_hash_map
PiperOrigin-RevId: 215272497
Diffstat (limited to 'tensorflow/compiler/jit')
-rw-r--r--tensorflow/compiler/jit/BUILD5
-rw-r--r--tensorflow/compiler/jit/deadness_analysis.cc22
-rw-r--r--tensorflow/compiler/jit/deadness_analysis_internal.h4
-rw-r--r--tensorflow/compiler/jit/kernels/BUILD1
-rw-r--r--tensorflow/compiler/jit/kernels/xla_ops.cc3
-rw-r--r--tensorflow/compiler/jit/mark_for_compilation_pass_test.cc11
-rw-r--r--tensorflow/compiler/jit/resource_operation_safety_analysis.cc1
-rw-r--r--tensorflow/compiler/jit/xla_compilation_cache.h6
8 files changed, 31 insertions, 22 deletions
diff --git a/tensorflow/compiler/jit/BUILD b/tensorflow/compiler/jit/BUILD
index 5bf4af1014..29b60d1dbe 100644
--- a/tensorflow/compiler/jit/BUILD
+++ b/tensorflow/compiler/jit/BUILD
@@ -258,6 +258,7 @@ cc_library(
"//tensorflow/core:lib_internal",
"//tensorflow/core:protos_all_cc",
"//tensorflow/core/kernels:variable_ops",
+ "@com_google_absl//absl/container:flat_hash_map",
],
)
@@ -323,6 +324,7 @@ cc_library(
"//tensorflow/core:graph",
"//tensorflow/core:lib",
"//tensorflow/core:protos_all_cc",
+ "@com_google_absl//absl/container:flat_hash_map",
"@com_google_absl//absl/memory",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/types:optional",
@@ -400,6 +402,7 @@ cc_library(
"//tensorflow/core:protos_all_cc",
"//tensorflow/core/kernels:bounds_check",
"@com_google_absl//absl/algorithm:container",
+ "@com_google_absl//absl/container:flat_hash_map",
"@com_google_absl//absl/memory",
"@com_google_absl//absl/strings",
],
@@ -471,6 +474,7 @@ tf_cc_test(
"//tensorflow/core:test",
"//tensorflow/core:test_main",
"//tensorflow/core:testlib",
+ "@com_google_absl//absl/container:flat_hash_map",
],
)
@@ -509,6 +513,7 @@ tf_cc_test(
"//tensorflow/core:test_main",
"//tensorflow/core:testlib",
"//tensorflow/core/grappler/optimizers/data:graph_utils",
+ "@com_google_absl//absl/container:flat_hash_map",
"@com_google_absl//absl/memory",
"@com_google_absl//absl/strings",
],
diff --git a/tensorflow/compiler/jit/deadness_analysis.cc b/tensorflow/compiler/jit/deadness_analysis.cc
index 25e2e9a7af..e63d4b7792 100644
--- a/tensorflow/compiler/jit/deadness_analysis.cc
+++ b/tensorflow/compiler/jit/deadness_analysis.cc
@@ -15,6 +15,7 @@ limitations under the License.
#include "tensorflow/compiler/jit/deadness_analysis.h"
#include "absl/algorithm/container.h"
+#include "absl/container/flat_hash_map.h"
#include "absl/strings/str_join.h"
#include "tensorflow/compiler/jit/deadness_analysis_internal.h"
#include "tensorflow/core/graph/algorithm.h"
@@ -420,15 +421,15 @@ class PredicateFactory {
}
};
- gtl::FlatMap<SignatureForAndOr, std::unique_ptr<Predicate>,
- HashSignatureForAndOr>
+ absl::flat_hash_map<SignatureForAndOr, std::unique_ptr<Predicate>,
+ HashSignatureForAndOr>
interned_and_or_instances_;
- gtl::FlatMap<SignatureForNot, std::unique_ptr<Predicate>>
+ absl::flat_hash_map<SignatureForNot, std::unique_ptr<Predicate>>
interned_not_instances_;
- gtl::FlatMap<SignatureForAndRec, std::unique_ptr<Predicate>>
+ absl::flat_hash_map<SignatureForAndRec, std::unique_ptr<Predicate>>
interned_and_rec_instances_;
- gtl::FlatMap<SignatureForSymbol, std::unique_ptr<Predicate>,
- HashSignatureForSymbol>
+ absl::flat_hash_map<SignatureForSymbol, std::unique_ptr<Predicate>,
+ HashSignatureForSymbol>
interned_symbol_instances_;
};
@@ -572,7 +573,8 @@ class DeadnessAnalysisImpl : public DeadnessAnalysis {
Status PopulateWithReversePostOrder(absl::Span<Node* const> rpo);
bool HasInputsWithMismatchingDeadness(const Node& node) override;
void Print() const override;
- gtl::FlatMap<TensorId, string, TensorId::Hasher> PredicateMapAsString() const;
+ absl::flat_hash_map<TensorId, string, TensorId::Hasher> PredicateMapAsString()
+ const;
private:
enum class EdgeKind { kDataAndControl, kDataOnly, kControlOnly };
@@ -614,7 +616,7 @@ class DeadnessAnalysisImpl : public DeadnessAnalysis {
Status HandleNode(Node* n, std::vector<bool>* should_revisit);
const Graph& graph_;
- gtl::FlatMap<TensorId, Predicate*, TensorId::Hasher> predicate_map_;
+ absl::flat_hash_map<TensorId, Predicate*, TensorId::Hasher> predicate_map_;
PredicateFactory predicate_factory_;
bool vlog_;
};
@@ -977,9 +979,9 @@ DeadnessAnalysis::~DeadnessAnalysis() {}
return Status::OK();
}
-gtl::FlatMap<TensorId, string, TensorId::Hasher>
+absl::flat_hash_map<TensorId, string, TensorId::Hasher>
DeadnessAnalysisImpl::PredicateMapAsString() const {
- gtl::FlatMap<TensorId, string, TensorId::Hasher> result;
+ absl::flat_hash_map<TensorId, string, TensorId::Hasher> result;
std::vector<TensorId> tensor_ids;
for (const auto& kv_pair : predicate_map_) {
CHECK(result.insert({kv_pair.first, kv_pair.second->ToString()}).second);
diff --git a/tensorflow/compiler/jit/deadness_analysis_internal.h b/tensorflow/compiler/jit/deadness_analysis_internal.h
index 3df2679c62..354782374a 100644
--- a/tensorflow/compiler/jit/deadness_analysis_internal.h
+++ b/tensorflow/compiler/jit/deadness_analysis_internal.h
@@ -16,15 +16,15 @@ limitations under the License.
#ifndef TENSORFLOW_COMPILER_JIT_DEADNESS_ANALYSIS_INTERNAL_H_
#define TENSORFLOW_COMPILER_JIT_DEADNESS_ANALYSIS_INTERNAL_H_
+#include "absl/container/flat_hash_map.h"
#include "tensorflow/core/graph/tensor_id.h"
-#include "tensorflow/core/lib/gtl/flatmap.h"
namespace tensorflow {
namespace deadness_analysis_internal {
// Returns a map describing the predicate each Tensor was mapped to. For
// testing purposes only.
-using PredicateMapTy = gtl::FlatMap<TensorId, string, TensorId::Hasher>;
+using PredicateMapTy = absl::flat_hash_map<TensorId, string, TensorId::Hasher>;
Status ComputePredicates(const Graph& graph, PredicateMapTy* out_predicate_map);
// Returns a map describing the predicate each Tensor was mapped to. For
diff --git a/tensorflow/compiler/jit/kernels/BUILD b/tensorflow/compiler/jit/kernels/BUILD
index 0839f1cb3d..26cb3af9d6 100644
--- a/tensorflow/compiler/jit/kernels/BUILD
+++ b/tensorflow/compiler/jit/kernels/BUILD
@@ -26,6 +26,7 @@ cc_library(
"//tensorflow/core:lib",
"//tensorflow/core:stream_executor_no_cuda",
"//tensorflow/core/kernels:variable_ops",
+ "@com_google_absl//absl/container:flat_hash_map",
"@com_google_absl//absl/memory",
],
alwayslink = 1,
diff --git a/tensorflow/compiler/jit/kernels/xla_ops.cc b/tensorflow/compiler/jit/kernels/xla_ops.cc
index a85006eb03..cfd27a6510 100644
--- a/tensorflow/compiler/jit/kernels/xla_ops.cc
+++ b/tensorflow/compiler/jit/kernels/xla_ops.cc
@@ -15,6 +15,7 @@ limitations under the License.
#include "tensorflow/compiler/jit/kernels/xla_ops.h"
+#include "absl/container/flat_hash_map.h"
#include "absl/memory/memory.h"
#include "tensorflow/compiler/jit/defs.h"
#include "tensorflow/compiler/tf2xla/shape_util.h"
@@ -163,7 +164,7 @@ class XlaExecutableClosureStore {
private:
mutex mutex_;
int64 key_counter_ GUARDED_BY(mutex_);
- gtl::FlatMap<KeyT, XlaExecutableClosure> closures_ GUARDED_BY(mutex_);
+ absl::flat_hash_map<KeyT, XlaExecutableClosure> closures_ GUARDED_BY(mutex_);
TF_DISALLOW_COPY_AND_ASSIGN(XlaExecutableClosureStore);
};
diff --git a/tensorflow/compiler/jit/mark_for_compilation_pass_test.cc b/tensorflow/compiler/jit/mark_for_compilation_pass_test.cc
index 4f9145b479..2a80c745e3 100644
--- a/tensorflow/compiler/jit/mark_for_compilation_pass_test.cc
+++ b/tensorflow/compiler/jit/mark_for_compilation_pass_test.cc
@@ -15,6 +15,7 @@ limitations under the License.
#include "tensorflow/compiler/jit/mark_for_compilation_pass_test_helper.h"
+#include "absl/container/flat_hash_map.h"
#include "absl/memory/memory.h"
#include "absl/strings/match.h"
#include "tensorflow/cc/framework/ops.h"
@@ -61,10 +62,10 @@ std::unordered_map<string, string> GetClusters(const Graph& graph) {
return ids;
}
-gtl::FlatMap<string, std::vector<string>> GetClusterSets(
+absl::flat_hash_map<string, std::vector<string>> GetClusterSets(
const Graph& g, std::vector<string>* cluster_names = nullptr) {
CHECK(cluster_names == nullptr || cluster_names->empty());
- gtl::FlatMap<string, std::vector<string>> cluster_sets;
+ absl::flat_hash_map<string, std::vector<string>> cluster_sets;
for (const auto& p : GetClusters(g)) {
cluster_sets[p.second].push_back(p.first);
}
@@ -566,7 +567,7 @@ TEST(XlaCompilationTest, ResourcesClusteringAllowed) {
std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
TF_EXPECT_OK(root.ToGraph(graph.get()));
TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph));
- gtl::FlatMap<string, std::vector<string>> cluster_sets =
+ absl::flat_hash_map<string, std::vector<string>> cluster_sets =
GetClusterSets(*graph);
ASSERT_EQ(cluster_sets.size(), 1);
std::vector<string> expected_clustered_nodes = {"AssignmentW", "ReadR",
@@ -586,7 +587,7 @@ TEST(XlaCompilationTest, ResourcesClusteringDisallowed) {
std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
TF_EXPECT_OK(root.ToGraph(graph.get()));
TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph));
- gtl::FlatMap<string, std::vector<string>> cluster_sets =
+ absl::flat_hash_map<string, std::vector<string>> cluster_sets =
GetClusterSets(*graph);
ASSERT_EQ(cluster_sets.size(), 1);
std::vector<string> expected_clustered_nodes = {"AssignmentW",
@@ -616,7 +617,7 @@ TEST(XlaCompilationTest, ChainOfOps) {
TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph));
std::vector<string> cluster_names;
- gtl::FlatMap<string, std::vector<string>> cluster_sets =
+ absl::flat_hash_map<string, std::vector<string>> cluster_sets =
GetClusterSets(*graph, &cluster_names);
ASSERT_EQ(cluster_sets.size(), 2);
diff --git a/tensorflow/compiler/jit/resource_operation_safety_analysis.cc b/tensorflow/compiler/jit/resource_operation_safety_analysis.cc
index 56e35c0059..657bb409db 100644
--- a/tensorflow/compiler/jit/resource_operation_safety_analysis.cc
+++ b/tensorflow/compiler/jit/resource_operation_safety_analysis.cc
@@ -89,7 +89,6 @@ limitations under the License.
#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/graph/algorithm.h"
#include "tensorflow/core/graph/tensor_id.h"
-#include "tensorflow/core/lib/gtl/flatmap.h"
#include "tensorflow/core/lib/gtl/flatset.h"
#include "tensorflow/core/lib/hash/hash.h"
#include "tensorflow/core/util/ptr_util.h"
diff --git a/tensorflow/compiler/jit/xla_compilation_cache.h b/tensorflow/compiler/jit/xla_compilation_cache.h
index 10ad87e38c..17c0321c1e 100644
--- a/tensorflow/compiler/jit/xla_compilation_cache.h
+++ b/tensorflow/compiler/jit/xla_compilation_cache.h
@@ -16,6 +16,7 @@ limitations under the License.
#ifndef TENSORFLOW_COMPILER_JIT_XLA_COMPILATION_CACHE_H_
#define TENSORFLOW_COMPILER_JIT_XLA_COMPILATION_CACHE_H_
+#include "absl/container/flat_hash_map.h"
#include "tensorflow/compiler/tf2xla/xla_compiler.h"
#include "tensorflow/compiler/tf2xla/xla_context.h"
#include "tensorflow/compiler/xla/client/local_client.h"
@@ -24,7 +25,6 @@ limitations under the License.
#include "tensorflow/core/framework/graph.pb.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/lib/core/threadpool.h"
-#include "tensorflow/core/lib/gtl/flatmap.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/thread_annotations.h"
@@ -152,7 +152,7 @@ class XlaCompilationCache : public ResourceBase {
};
mutex compile_cache_mu_;
- gtl::FlatMap<Signature, std::unique_ptr<Entry>, Signature::Hash> cache_
+ absl::flat_hash_map<Signature, std::unique_ptr<Entry>, Signature::Hash> cache_
GUARDED_BY(compile_cache_mu_);
struct CompileStats {
@@ -165,7 +165,7 @@ class XlaCompilationCache : public ResourceBase {
mutex compile_stats_mu_;
// Maps cluster names to compilation statistics for said cluster.
- gtl::FlatMap<string, CompileStats> compile_stats_
+ absl::flat_hash_map<string, CompileStats> compile_stats_
GUARDED_BY(compile_stats_mu_);
TF_DISALLOW_COPY_AND_ASSIGN(XlaCompilationCache);