diff options
author | Benjamin Kramer <kramerb@google.com> | 2018-10-01 13:43:49 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-10-01 13:53:07 -0700 |
commit | 3039a4694e22674b502257ae34b0a5b614a631f3 (patch) | |
tree | 423fdfa7a2e7dd2740af97accfe848bc97b335d0 /tensorflow/compiler/jit | |
parent | 44acd839c57494860666c799afd24360f1df3bed (diff) |
[XLA] Migrate from gtl::FlatMap to absl::flat_hash_map
PiperOrigin-RevId: 215272497
Diffstat (limited to 'tensorflow/compiler/jit')
-rw-r--r-- | tensorflow/compiler/jit/BUILD | 5 | ||||
-rw-r--r-- | tensorflow/compiler/jit/deadness_analysis.cc | 22 | ||||
-rw-r--r-- | tensorflow/compiler/jit/deadness_analysis_internal.h | 4 | ||||
-rw-r--r-- | tensorflow/compiler/jit/kernels/BUILD | 1 | ||||
-rw-r--r-- | tensorflow/compiler/jit/kernels/xla_ops.cc | 3 | ||||
-rw-r--r-- | tensorflow/compiler/jit/mark_for_compilation_pass_test.cc | 11 | ||||
-rw-r--r-- | tensorflow/compiler/jit/resource_operation_safety_analysis.cc | 1 | ||||
-rw-r--r-- | tensorflow/compiler/jit/xla_compilation_cache.h | 6 |
8 files changed, 31 insertions, 22 deletions
diff --git a/tensorflow/compiler/jit/BUILD b/tensorflow/compiler/jit/BUILD index 5bf4af1014..29b60d1dbe 100644 --- a/tensorflow/compiler/jit/BUILD +++ b/tensorflow/compiler/jit/BUILD @@ -258,6 +258,7 @@ cc_library( "//tensorflow/core:lib_internal", "//tensorflow/core:protos_all_cc", "//tensorflow/core/kernels:variable_ops", + "@com_google_absl//absl/container:flat_hash_map", ], ) @@ -323,6 +324,7 @@ cc_library( "//tensorflow/core:graph", "//tensorflow/core:lib", "//tensorflow/core:protos_all_cc", + "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/memory", "@com_google_absl//absl/strings", "@com_google_absl//absl/types:optional", @@ -400,6 +402,7 @@ cc_library( "//tensorflow/core:protos_all_cc", "//tensorflow/core/kernels:bounds_check", "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/memory", "@com_google_absl//absl/strings", ], @@ -471,6 +474,7 @@ tf_cc_test( "//tensorflow/core:test", "//tensorflow/core:test_main", "//tensorflow/core:testlib", + "@com_google_absl//absl/container:flat_hash_map", ], ) @@ -509,6 +513,7 @@ tf_cc_test( "//tensorflow/core:test_main", "//tensorflow/core:testlib", "//tensorflow/core/grappler/optimizers/data:graph_utils", + "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/memory", "@com_google_absl//absl/strings", ], diff --git a/tensorflow/compiler/jit/deadness_analysis.cc b/tensorflow/compiler/jit/deadness_analysis.cc index 25e2e9a7af..e63d4b7792 100644 --- a/tensorflow/compiler/jit/deadness_analysis.cc +++ b/tensorflow/compiler/jit/deadness_analysis.cc @@ -15,6 +15,7 @@ limitations under the License. #include "tensorflow/compiler/jit/deadness_analysis.h" #include "absl/algorithm/container.h" +#include "absl/container/flat_hash_map.h" #include "absl/strings/str_join.h" #include "tensorflow/compiler/jit/deadness_analysis_internal.h" #include "tensorflow/core/graph/algorithm.h" @@ -420,15 +421,15 @@ class PredicateFactory { } }; - gtl::FlatMap<SignatureForAndOr, std::unique_ptr<Predicate>, - HashSignatureForAndOr> + absl::flat_hash_map<SignatureForAndOr, std::unique_ptr<Predicate>, + HashSignatureForAndOr> interned_and_or_instances_; - gtl::FlatMap<SignatureForNot, std::unique_ptr<Predicate>> + absl::flat_hash_map<SignatureForNot, std::unique_ptr<Predicate>> interned_not_instances_; - gtl::FlatMap<SignatureForAndRec, std::unique_ptr<Predicate>> + absl::flat_hash_map<SignatureForAndRec, std::unique_ptr<Predicate>> interned_and_rec_instances_; - gtl::FlatMap<SignatureForSymbol, std::unique_ptr<Predicate>, - HashSignatureForSymbol> + absl::flat_hash_map<SignatureForSymbol, std::unique_ptr<Predicate>, + HashSignatureForSymbol> interned_symbol_instances_; }; @@ -572,7 +573,8 @@ class DeadnessAnalysisImpl : public DeadnessAnalysis { Status PopulateWithReversePostOrder(absl::Span<Node* const> rpo); bool HasInputsWithMismatchingDeadness(const Node& node) override; void Print() const override; - gtl::FlatMap<TensorId, string, TensorId::Hasher> PredicateMapAsString() const; + absl::flat_hash_map<TensorId, string, TensorId::Hasher> PredicateMapAsString() + const; private: enum class EdgeKind { kDataAndControl, kDataOnly, kControlOnly }; @@ -614,7 +616,7 @@ class DeadnessAnalysisImpl : public DeadnessAnalysis { Status HandleNode(Node* n, std::vector<bool>* should_revisit); const Graph& graph_; - gtl::FlatMap<TensorId, Predicate*, TensorId::Hasher> predicate_map_; + absl::flat_hash_map<TensorId, Predicate*, TensorId::Hasher> predicate_map_; PredicateFactory predicate_factory_; bool vlog_; }; @@ -977,9 +979,9 @@ DeadnessAnalysis::~DeadnessAnalysis() {} return Status::OK(); } -gtl::FlatMap<TensorId, string, TensorId::Hasher> +absl::flat_hash_map<TensorId, string, TensorId::Hasher> DeadnessAnalysisImpl::PredicateMapAsString() const { - gtl::FlatMap<TensorId, string, TensorId::Hasher> result; + absl::flat_hash_map<TensorId, string, TensorId::Hasher> result; std::vector<TensorId> tensor_ids; for (const auto& kv_pair : predicate_map_) { CHECK(result.insert({kv_pair.first, kv_pair.second->ToString()}).second); diff --git a/tensorflow/compiler/jit/deadness_analysis_internal.h b/tensorflow/compiler/jit/deadness_analysis_internal.h index 3df2679c62..354782374a 100644 --- a/tensorflow/compiler/jit/deadness_analysis_internal.h +++ b/tensorflow/compiler/jit/deadness_analysis_internal.h @@ -16,15 +16,15 @@ limitations under the License. #ifndef TENSORFLOW_COMPILER_JIT_DEADNESS_ANALYSIS_INTERNAL_H_ #define TENSORFLOW_COMPILER_JIT_DEADNESS_ANALYSIS_INTERNAL_H_ +#include "absl/container/flat_hash_map.h" #include "tensorflow/core/graph/tensor_id.h" -#include "tensorflow/core/lib/gtl/flatmap.h" namespace tensorflow { namespace deadness_analysis_internal { // Returns a map describing the predicate each Tensor was mapped to. For // testing purposes only. -using PredicateMapTy = gtl::FlatMap<TensorId, string, TensorId::Hasher>; +using PredicateMapTy = absl::flat_hash_map<TensorId, string, TensorId::Hasher>; Status ComputePredicates(const Graph& graph, PredicateMapTy* out_predicate_map); // Returns a map describing the predicate each Tensor was mapped to. For diff --git a/tensorflow/compiler/jit/kernels/BUILD b/tensorflow/compiler/jit/kernels/BUILD index 0839f1cb3d..26cb3af9d6 100644 --- a/tensorflow/compiler/jit/kernels/BUILD +++ b/tensorflow/compiler/jit/kernels/BUILD @@ -26,6 +26,7 @@ cc_library( "//tensorflow/core:lib", "//tensorflow/core:stream_executor_no_cuda", "//tensorflow/core/kernels:variable_ops", + "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/memory", ], alwayslink = 1, diff --git a/tensorflow/compiler/jit/kernels/xla_ops.cc b/tensorflow/compiler/jit/kernels/xla_ops.cc index a85006eb03..cfd27a6510 100644 --- a/tensorflow/compiler/jit/kernels/xla_ops.cc +++ b/tensorflow/compiler/jit/kernels/xla_ops.cc @@ -15,6 +15,7 @@ limitations under the License. #include "tensorflow/compiler/jit/kernels/xla_ops.h" +#include "absl/container/flat_hash_map.h" #include "absl/memory/memory.h" #include "tensorflow/compiler/jit/defs.h" #include "tensorflow/compiler/tf2xla/shape_util.h" @@ -163,7 +164,7 @@ class XlaExecutableClosureStore { private: mutex mutex_; int64 key_counter_ GUARDED_BY(mutex_); - gtl::FlatMap<KeyT, XlaExecutableClosure> closures_ GUARDED_BY(mutex_); + absl::flat_hash_map<KeyT, XlaExecutableClosure> closures_ GUARDED_BY(mutex_); TF_DISALLOW_COPY_AND_ASSIGN(XlaExecutableClosureStore); }; diff --git a/tensorflow/compiler/jit/mark_for_compilation_pass_test.cc b/tensorflow/compiler/jit/mark_for_compilation_pass_test.cc index 4f9145b479..2a80c745e3 100644 --- a/tensorflow/compiler/jit/mark_for_compilation_pass_test.cc +++ b/tensorflow/compiler/jit/mark_for_compilation_pass_test.cc @@ -15,6 +15,7 @@ limitations under the License. #include "tensorflow/compiler/jit/mark_for_compilation_pass_test_helper.h" +#include "absl/container/flat_hash_map.h" #include "absl/memory/memory.h" #include "absl/strings/match.h" #include "tensorflow/cc/framework/ops.h" @@ -61,10 +62,10 @@ std::unordered_map<string, string> GetClusters(const Graph& graph) { return ids; } -gtl::FlatMap<string, std::vector<string>> GetClusterSets( +absl::flat_hash_map<string, std::vector<string>> GetClusterSets( const Graph& g, std::vector<string>* cluster_names = nullptr) { CHECK(cluster_names == nullptr || cluster_names->empty()); - gtl::FlatMap<string, std::vector<string>> cluster_sets; + absl::flat_hash_map<string, std::vector<string>> cluster_sets; for (const auto& p : GetClusters(g)) { cluster_sets[p.second].push_back(p.first); } @@ -566,7 +567,7 @@ TEST(XlaCompilationTest, ResourcesClusteringAllowed) { std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global())); TF_EXPECT_OK(root.ToGraph(graph.get())); TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph)); - gtl::FlatMap<string, std::vector<string>> cluster_sets = + absl::flat_hash_map<string, std::vector<string>> cluster_sets = GetClusterSets(*graph); ASSERT_EQ(cluster_sets.size(), 1); std::vector<string> expected_clustered_nodes = {"AssignmentW", "ReadR", @@ -586,7 +587,7 @@ TEST(XlaCompilationTest, ResourcesClusteringDisallowed) { std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global())); TF_EXPECT_OK(root.ToGraph(graph.get())); TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph)); - gtl::FlatMap<string, std::vector<string>> cluster_sets = + absl::flat_hash_map<string, std::vector<string>> cluster_sets = GetClusterSets(*graph); ASSERT_EQ(cluster_sets.size(), 1); std::vector<string> expected_clustered_nodes = {"AssignmentW", @@ -616,7 +617,7 @@ TEST(XlaCompilationTest, ChainOfOps) { TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph)); std::vector<string> cluster_names; - gtl::FlatMap<string, std::vector<string>> cluster_sets = + absl::flat_hash_map<string, std::vector<string>> cluster_sets = GetClusterSets(*graph, &cluster_names); ASSERT_EQ(cluster_sets.size(), 2); diff --git a/tensorflow/compiler/jit/resource_operation_safety_analysis.cc b/tensorflow/compiler/jit/resource_operation_safety_analysis.cc index 56e35c0059..657bb409db 100644 --- a/tensorflow/compiler/jit/resource_operation_safety_analysis.cc +++ b/tensorflow/compiler/jit/resource_operation_safety_analysis.cc @@ -89,7 +89,6 @@ limitations under the License. #include "tensorflow/core/framework/node_def.pb.h" #include "tensorflow/core/graph/algorithm.h" #include "tensorflow/core/graph/tensor_id.h" -#include "tensorflow/core/lib/gtl/flatmap.h" #include "tensorflow/core/lib/gtl/flatset.h" #include "tensorflow/core/lib/hash/hash.h" #include "tensorflow/core/util/ptr_util.h" diff --git a/tensorflow/compiler/jit/xla_compilation_cache.h b/tensorflow/compiler/jit/xla_compilation_cache.h index 10ad87e38c..17c0321c1e 100644 --- a/tensorflow/compiler/jit/xla_compilation_cache.h +++ b/tensorflow/compiler/jit/xla_compilation_cache.h @@ -16,6 +16,7 @@ limitations under the License. #ifndef TENSORFLOW_COMPILER_JIT_XLA_COMPILATION_CACHE_H_ #define TENSORFLOW_COMPILER_JIT_XLA_COMPILATION_CACHE_H_ +#include "absl/container/flat_hash_map.h" #include "tensorflow/compiler/tf2xla/xla_compiler.h" #include "tensorflow/compiler/tf2xla/xla_context.h" #include "tensorflow/compiler/xla/client/local_client.h" @@ -24,7 +25,6 @@ limitations under the License. #include "tensorflow/core/framework/graph.pb.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/lib/core/threadpool.h" -#include "tensorflow/core/lib/gtl/flatmap.h" #include "tensorflow/core/platform/mutex.h" #include "tensorflow/core/platform/thread_annotations.h" @@ -152,7 +152,7 @@ class XlaCompilationCache : public ResourceBase { }; mutex compile_cache_mu_; - gtl::FlatMap<Signature, std::unique_ptr<Entry>, Signature::Hash> cache_ + absl::flat_hash_map<Signature, std::unique_ptr<Entry>, Signature::Hash> cache_ GUARDED_BY(compile_cache_mu_); struct CompileStats { @@ -165,7 +165,7 @@ class XlaCompilationCache : public ResourceBase { mutex compile_stats_mu_; // Maps cluster names to compilation statistics for said cluster. - gtl::FlatMap<string, CompileStats> compile_stats_ + absl::flat_hash_map<string, CompileStats> compile_stats_ GUARDED_BY(compile_stats_mu_); TF_DISALLOW_COPY_AND_ASSIGN(XlaCompilationCache); |